Skip to content

Commit 30432ce

Browse files
Fix findall output type (#587)
1 parent 987264c commit 30432ce

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function Base.findall(bools::WrappedMtlArray{Bool})
2626
indices = cumsum(reshape(bools, prod(size(bools))))
2727

2828
n = @allowscalar indices[end]
29-
ys = MtlArray{I}(undef, n)
29+
ys = similar(bools, I, n)
3030

3131
if n > 0
3232
function kernel(ys::MtlDeviceArray, bools, indices)

test/array.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,13 @@ end
526526
@test testf(x->findall(x), rand(Bool, 1000))
527527
@test testf(x->findall(y->y>Float32(0.5), x), rand(Float32,1000))
528528

529+
# Set storage mode to a different one than the default
530+
let storage=Metal.DefaultStorageMode == Metal.PrivateStorage ? Metal.SharedStorage : Metal.PrivateStorage
531+
x = mtl(rand(Float32,100); storage)
532+
out = findall(y->y>Float32(0.5), x)
533+
@test Metal.storagemode(x) == Metal.storagemode(out)
534+
end
535+
529536
# ND
530537
let x = rand(Bool, 1000, 1000)
531538
@test findall(x) == Array(findall(MtlArray(x)))

0 commit comments

Comments
 (0)