Skip to content

Commit 9edcffa

Browse files
committed
Port over accumulations
Also tests
1 parent 15ea972 commit 9edcffa

File tree

4 files changed

+145
-0
lines changed

4 files changed

+145
-0
lines changed

src/GPUArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ include("host/construction.jl")
2828
include("host/base.jl")
2929
include("host/indexing.jl")
3030
include("host/broadcast.jl")
31+
include("host/accumulate.jl")
3132
include("host/mapreduce.jl")
3233
include("host/linalg.jl")
3334
include("host/math.jl")

src/host/accumulate.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
## Base interface
2+
3+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Nothing) =
4+
AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output)))
5+
6+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Nothing) =
7+
AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output)))
8+
9+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Some) =
10+
AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init))
11+
12+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Some) =
13+
AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init))
14+
15+
Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector) = accumulate!(op, result, v)
16+
17+
# default behavior unless dims are specified by the user
18+
function Base.accumulate(op, A::AnyGPUArray;
19+
dims::Union{Nothing,Integer}=nothing, kw...)
20+
nt = values(kw)
21+
if dims === nothing && !(A isa AbstractVector)
22+
# This branch takes care of the cases not handled by `_accumulate!`.
23+
return reshape(AK.accumulate(op, A[:], get_backend(A); init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A)))), size(A))
24+
end
25+
if isempty(kw)
26+
out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))
27+
init = AK.neutral_element(op, eltype(out))
28+
elseif keys(nt) === (:init,)
29+
out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A)))
30+
init = nt.init
31+
else
32+
throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))
33+
end
34+
AK.accumulate!(op, out, A, get_backend(A); dims, init)
35+
end

test/testsuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ include("testsuite/indexing.jl")
9090
include("testsuite/base.jl")
9191
include("testsuite/vector.jl")
9292
include("testsuite/reductions.jl")
93+
include("testsuite/accumulations.jl")
9394
include("testsuite/broadcasting.jl")
9495
include("testsuite/linalg.jl")
9596
include("testsuite/math.jl")

test/testsuite/accumulations.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
@testsuite "accumulations" (AT, eltypes)->begin
2+
@testset "$ET" for ET in eltypes
3+
range = ET <: Real ? (ET(1):ET(10)) : ET
4+
5+
# 1d arrays
6+
for num_elems in 1:256
7+
@test compare(A->accumulate(+, A; init=zero(ET)), AT, rand(range, num_elems))
8+
end
9+
10+
for num_elems = rand(1:100, 10)
11+
@test compare(A->accumulate(+, A; init=zero(ET)), AT, rand(range, num_elems))
12+
end
13+
14+
for _ in 1:10 # nd arrays reduced as 1d
15+
n1 = rand(1:10)
16+
n2 = rand(1:10)
17+
n3 = rand(1:10)
18+
@test compare(A->accumulate(+, A; init=zero(ET)), AT, rand(range, n1, n2, n3))
19+
end
20+
21+
for num_elems = rand(1:100, 10) # init value
22+
init = rand(range)
23+
@test compare(A->accumulate(+, A; init), AT, rand(range, num_elems))
24+
end
25+
26+
27+
# nd arrays
28+
for dims in 1:4 # corner cases
29+
for isize in 1:3
30+
for jsize in 1:3
31+
for ksize in 1:3
32+
@test compare(A->accumulate(+, A; dims, init=zero(ET)), AT, rand(range, isize, jsize, ksize))
33+
end
34+
end
35+
end
36+
end
37+
38+
for _ in 1:10
39+
for dims in 1:3
40+
n1 = rand(1:10)
41+
n2 = rand(1:10)
42+
n3 = rand(1:10)
43+
@test compare(A->accumulate(+, A; dims, init=zero(ET)), AT, rand(range, n1, n2, n3))
44+
end
45+
end
46+
47+
for _ in 1:10 # init value
48+
for dims in 1:3
49+
n1 = rand(1:10)
50+
n2 = rand(1:10)
51+
n3 = rand(1:10)
52+
init = rand(range)
53+
@test compare(A->accumulate(+, A; init, dims), AT, rand(range, n1, n2, n3))
54+
end
55+
end
56+
57+
# Larger containers to try and detect weird bugs
58+
for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not
59+
# Skip large tests on small datatypes
60+
n >= 10000 && sizeof(real(ET)) <= 2 && continue
61+
62+
@test compare(x->accumulate(+, x), AT, rand(range, n))
63+
@test compare(x->accumulate(+, x), AT, rand(range, n, 2))
64+
@test compare(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(range)), AT, rand(range, n))
65+
end
66+
67+
# in place
68+
@test compare(x->(accumulate!(+, x, copy(x)); x), AT, rand(range, 2))
69+
70+
@test_throws ArgumentError("accumulate does not support the keyword arguments [:bad_kwarg]") accumulate(+, AT(rand(ET, 10)); bad_kwarg="bad")
71+
end
72+
end
73+
74+
@testsuite "accumulations/cumsum & cumprod" (AT, eltypes)->begin
75+
@test compare(cumsum, AT, rand(Bool, 16))
76+
77+
@testset "$ET" for ET in eltypes
78+
range = ET <: Real ? (ET(1):ET(10)) : ET
79+
80+
# cumsum
81+
for num_elems in rand(1:100, 10)
82+
@test compare(A->cumsum(A; dims=1), AT, rand(range, num_elems))
83+
end
84+
85+
for _ in 1:10
86+
for dims in 1:3
87+
n1 = rand(1:10)
88+
n2 = rand(1:10)
89+
n3 = rand(1:10)
90+
@test compare(A->cumsum(A; dims), AT, rand(range, n1, n2, n3))
91+
end
92+
end
93+
94+
95+
# cumprod
96+
range = ET <: Real ? (ET(1):ET(10)) : ET
97+
@test compare(A->cumprod(A; dims=1), AT, ones(ET, 100_000))
98+
99+
for _ in 1:10
100+
for dims in 1:3
101+
n1 = rand(1:10)
102+
n2 = rand(1:10)
103+
n3 = rand(1:10)
104+
@test compare(A->cumprod(A; dims), AT, rand(range, n1, n2, n3))
105+
end
106+
end
107+
end
108+
end

0 commit comments

Comments
 (0)