@@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
347347 R
348348end
349349
350+ # # Base interface
351+
352+ Base. _accumulate!(op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Nothing ) =
353+ accumulate!(op, typed_data(output), typed_data(input); dims= 1 )
354+
355+ Base. _accumulate!(op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Nothing ) =
356+ accumulate!(op, typed_data(output), typed_data(input); dims)
357+
358+ Base. _accumulate!(op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Some ) =
359+ accumulate!(op, typed_data(output), typed_data(input); dims= 1 , init= something(init))
360+
361+ Base. _accumulate!(op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Some ) =
362+ accumulate!(op, typed_data(output), typed_data(input); dims, init= something(init))
363+
364+ Base. accumulate_pairwise!(op, result:: AnyJLVector , v:: AnyJLVector ) = accumulate!(op, result, v)
365+
366+ # default behavior unless dims are specified by the user
367+ function Base. accumulate(op, A:: AnyJLArray ;
368+ dims:: Union{Nothing,Integer} = nothing , kw... )
369+ nt = values(kw)
370+ if dims === nothing && ! (A isa AbstractVector)
371+ # This branch takes care of the cases not handled by `_accumulate!`.
372+ return reshape(accumulate(op, typed_data(A)[:]; kw... ), size(A))
373+ end
374+ if isempty(kw)
375+ out = similar(A, Base. promote_op(op, eltype(A), eltype(A)))
376+ init = AK. neutral_element(op, eltype(out))
377+ elseif keys(nt) === (:init,)
378+ out = similar(A, Base. promote_op(op, typeof(nt. init), eltype(A)))
379+ init = nt. init
380+ else
381+ throw(ArgumentError(" accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,))) " ))
382+ end
383+ accumulate!(op, typed_data(out), typed_data(A); dims, init)
384+ end
385+
386+
350387# # KernelAbstractions interface
351388
352389KernelAbstractions. get_backend(a:: JLA ) where JLA <: JLArray = JLBackend()
0 commit comments