Skip to content

Commit c37267a

Browse files
committed
Port over accumulations
1 parent 023896d commit c37267a

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

src/host/accumulate.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
## Base interface
2+
3+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Nothing) =
4+
AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output)))
5+
6+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Nothing) =
7+
AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output)))
8+
9+
Base._accumulate!(op, output::AnyGPUArray, input::MtlVector, dims::Nothing, init::Some) =
10+
AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init))
11+
12+
Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Some) =
13+
AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init))
14+
15+
Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector) = accumulate!(op, result, v)
16+
17+
# default behavior unless dims are specified by the user
18+
function Base.accumulate(op, A::WrappedGPUArray;
19+
dims::Union{Nothing,Integer}=nothing, kw...)
20+
nt = values(kw)
21+
if dims === nothing && !(A isa AbstractVector)
22+
# This branch takes care of the cases not handled by `_accumulate!`.
23+
return reshape(AK.accumulate(op, A[:], get_backend(A); init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A)))), size(A))
24+
end
25+
if isempty(kw)
26+
out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))
27+
init = AK.neutral_element(op, eltype(out))
28+
elseif keys(nt) === (:init,)
29+
out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A)))
30+
init = nt.init
31+
else
32+
throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))
33+
end
34+
AK.accumulate!(op, out, A, get_backend(A); dims, init)
35+
end

test/testsuite.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ include("testsuite/random.jl")
9696
include("testsuite/uniformscaling.jl")
9797
include("testsuite/statistics.jl")
9898
include("testsuite/alloc_cache.jl")
99+
include("testsuite/accumulations.jl")
99100
include("testsuite/jld2ext.jl")
100101

101102
"""

test/testsuite/accumulations.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# @testsuite "accumulations" (AT, eltypes)->begin
21
@testsuite "accumulations" (AT, eltypes)->begin
32
@testset "$ET" for ET in eltypes
43
range = ET <: Real ? (ET(1):ET(10)) : ET
@@ -54,6 +53,13 @@
5453
@test compare(A->accumulate(+, A; init, dims), AT, rand(range, n1, n2, n3))
5554
end
5655
end
56+
57+
# Larger containers to try and detect weird bugs
58+
for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not
59+
@test compare(x->accumulate(+, x), AT, rand(range, n))
60+
@test compare(x->accumulate(+, x), AT, rand(range, n, 2))
61+
@test compare(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(range)), AT, rand(range, n))
62+
end
5763
end
5864
end
5965

0 commit comments

Comments
 (0)