Skip to content

Commit 5fa011f

Browse files
committed
Fix JLArrays until KA 0.10
1 parent c9c6256 commit 5fa011f

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

lib/JLArrays/src/JLArrays.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
347347
R
348348
end
349349

350+
## Base interface
351+
352+
Base._accumulate!(op, output::AnyJLArray, input::AnyJLVector, dims::Nothing, init::Nothing) =
353+
accumulate!(op, typed_data(output), typed_data(input); dims=1)
354+
355+
Base._accumulate!(op, output::AnyJLArray, input::AnyJLArray, dims::Integer, init::Nothing) =
356+
accumulate!(op, typed_data(output), typed_data(input); dims)
357+
358+
Base._accumulate!(op, output::AnyJLArray, input::AnyJLVector, dims::Nothing, init::Some) =
359+
accumulate!(op, typed_data(output), typed_data(input); dims=1, init=something(init))
360+
361+
Base._accumulate!(op, output::AnyJLArray, input::AnyJLArray, dims::Integer, init::Some) =
362+
accumulate!(op, typed_data(output), typed_data(input); dims, init=something(init))
363+
364+
Base.accumulate_pairwise!(op, result::AnyJLVector, v::AnyJLVector) = accumulate!(op, result, v)
365+
366+
# default behavior unless dims are specified by the user
367+
function Base.accumulate(op, A::AnyJLArray;
368+
dims::Union{Nothing,Integer}=nothing, kw...)
369+
nt = values(kw)
370+
if dims === nothing && !(A isa AbstractVector)
371+
# This branch takes care of the cases not handled by `_accumulate!`.
372+
return reshape(accumulate(op, typed_data(A)[:]; kw...), size(A))
373+
end
374+
if isempty(kw)
375+
out = similar(A, Base.promote_op(op, eltype(A), eltype(A)))
376+
init = AK.neutral_element(op, eltype(out))
377+
elseif keys(nt) === (:init,)
378+
out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A)))
379+
init = nt.init
380+
else
381+
throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))
382+
end
383+
accumulate!(op, typed_data(out), typed_data(A); dims, init)
384+
end
385+
386+
350387
## KernelAbstractions interface
351388

352389
KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1212
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
15+
16+
[sources]
17+
JLArrays = {path = "/Users/christian/.julia/dev/GPUArrays/lib/JLArrays"}

0 commit comments

Comments
 (0)