Skip to content

Commit caf2996

Browse files
Minor findall and accumulate tests improvements (#592)
1 parent e1cb12e commit caf2996

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/indexing.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ end
2323

2424
function Base.findall(bools::WrappedMtlArray{Bool})
2525
I = keytype(bools)
26-
indices = cumsum(reshape(bools, prod(size(bools))))
26+
boolslen = prod(size(bools))
2727

28-
n = @allowscalar indices[end]
28+
indices = MtlVector{Int64, Metal.SharedStorage}(undef, boolslen)
29+
cumsum!(indices, reshape(bools, boolslen))
30+
31+
n = indices[end]
2932
ys = similar(bools, I, n)
3033

3134
if n > 0

test/array.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,6 @@ end
490490
end
491491

492492
@testset "accumulate" begin
493-
testf(f, x) = Array(f(MtlArray(x))) f(x)
494493
for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not
495494
@test testf(x->accumulate(+, x), rand(Float32, n))
496495
@test testf(x->accumulate(+, x), rand(Float32, n, 2))
@@ -500,17 +499,17 @@ end
500499
# multidimensional
501500
for (sizes, dims) in ((2,) => 2,
502501
(3,4,5) => 2,
503-
(1, 70, 50, 20) => 3)
504-
@test testf(x->accumulate(+, x; dims=dims), rand(Int, sizes))
505-
@test testf(x->accumulate(+, x), rand(Int, sizes))
502+
(1, 70, 50, 20) => 3,)
503+
@test testf(x->accumulate(+, x; dims=dims), rand(-10:10, sizes))
504+
@test testf(x->accumulate(+, x), rand(-10:10, sizes))
506505
end
507506

508507
# using initializer
509508
for (sizes, dims) in ((2,) => 2,
510509
(3,4,5) => 2,
511510
(1, 70, 50, 20) => 3)
512-
@test testf(Base.Fix2((x,y)->accumulate(+, x; dims=dims, init=y), rand(Int)), rand(Int, sizes))
513-
@test testf(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(Int)), rand(Int, sizes))
511+
@test testf(Base.Fix2((x,y)->accumulate(+, x; dims=dims, init=y), rand(-10:10)), rand(-10:10, sizes))
512+
@test testf(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(-10:10)), rand(-10:10, sizes))
514513
end
515514

516515
# in place

0 commit comments

Comments
 (0)