|
| 1 | +## Base interface |
| 2 | + |
| 3 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Nothing) = |
| 4 | + AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output))) |
| 5 | + |
| 6 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Nothing) = |
| 7 | + AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output))) |
| 8 | + |
| 9 | +Base._accumulate!(op, output::AnyGPUArray, input::MtlVector, dims::Nothing, init::Some) = |
| 10 | + AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init)) |
| 11 | + |
| 12 | +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Some) = |
| 13 | + AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init)) |
| 14 | + |
| 15 | +Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector) = accumulate!(op, result, v) |
| 16 | + |
| 17 | +# default behavior unless dims are specified by the user |
| 18 | +function Base.accumulate(op, A::WrappedGPUArray; |
| 19 | + dims::Union{Nothing,Integer}=nothing, kw...) |
| 20 | + nt = values(kw) |
| 21 | + if dims === nothing && !(A isa AbstractVector) |
| 22 | + # This branch takes care of the cases not handled by `_accumulate!`. |
| 23 | + return reshape(AK.accumulate(op, A[:], get_backend(A); init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A)))), size(A)) |
| 24 | + end |
| 25 | + if isempty(kw) |
| 26 | + out = similar(A, Base.promote_op(op, eltype(A), eltype(A))) |
| 27 | + init = AK.neutral_element(op, eltype(out)) |
| 28 | + elseif keys(nt) === (:init,) |
| 29 | + out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A))) |
| 30 | + init = nt.init |
| 31 | + else |
| 32 | + throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))")) |
| 33 | + end |
| 34 | + AK.accumulate!(op, out, A, get_backend(A); dims, init) |
| 35 | +end |
0 commit comments