@@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
347
347
R
348
348
end
349
349
350
+ # # Base interface
351
+
352
+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Nothing ) =
353
+ accumulate! (op, typed_data (output), typed_data (input); dims= 1 )
354
+
355
+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Nothing ) =
356
+ accumulate! (op, typed_data (output), typed_data (input); dims)
357
+
358
+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLVector , dims:: Nothing , init:: Some ) =
359
+ accumulate! (op, typed_data (output), typed_data (input); dims= 1 , init= something (init))
360
+
361
+ Base. _accumulate! (op, output:: AnyJLArray , input:: AnyJLArray , dims:: Integer , init:: Some ) =
362
+ accumulate! (op, typed_data (output), typed_data (input); dims, init= something (init))
363
+
364
+ Base. accumulate_pairwise! (op, result:: AnyJLVector , v:: AnyJLVector ) = accumulate! (op, result, v)
365
+
366
+ # default behavior unless dims are specified by the user
367
+ function Base. accumulate (op, A:: AnyJLArray ;
368
+ dims:: Union{Nothing,Integer} = nothing , kw... )
369
+ nt = values (kw)
370
+ if dims === nothing && ! (A isa AbstractVector)
371
+ # This branch takes care of the cases not handled by `_accumulate!`.
372
+ return reshape (accumulate (op, typed_data (A)[:]; kw... ), size (A))
373
+ end
374
+ if isempty (kw)
375
+ out = similar (A, Base. promote_op (op, eltype (A), eltype (A)))
376
+ init = AK. neutral_element (op, eltype (out))
377
+ elseif keys (nt) === (:init ,)
378
+ out = similar (A, Base. promote_op (op, typeof (nt. init), eltype (A)))
379
+ init = nt. init
380
+ else
381
+ throw (ArgumentError (" accumulate does not support the keyword arguments $(setdiff (keys (nt), (:init ,))) " ))
382
+ end
383
+ accumulate! (op, typed_data (out), typed_data (A); dims, init)
384
+ end
385
+
386
+
350
387
# # KernelAbstractions interface
351
388
352
389
KernelAbstractions. get_backend (a:: JLA ) where JLA <: JLArray = JLBackend ()
0 commit comments