diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 18505d55..95210991 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -48,7 +48,7 @@ steps: - label: "oneAPI.jl" plugins: - JuliaCI/julia#v1: - version: "1.10" + version: "1.11" - JuliaCI/julia-coverage#v1: codecov: true command: | @@ -95,7 +95,7 @@ steps: - label: "Metal.jl" plugins: - JuliaCI/julia#v1: - version: "1.10" + version: "1.11" - JuliaCI/julia-coverage#v1: codecov: true command: | diff --git a/Project.toml b/Project.toml index 0051650a..d52e498d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" version = "11.2.3" [deps] +AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -22,6 +23,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JLD2Ext = "JLD2" [compat] +AcceleratedKernels = "0.4" Adapt = "4.0" GPUArraysCore = "= 0.2.0" JLD2 = "0.4, 0.5" diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl index 6e5d38df..d0221f25 100644 --- a/lib/JLArrays/src/JLArrays.jl +++ b/lib/JLArrays/src/JLArrays.jl @@ -347,6 +347,43 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br R end +## Base interface + +Base._accumulate!(op, output::AnyJLArray, input::AnyJLVector, dims::Nothing, init::Nothing) = + accumulate!(op, typed_data(output), typed_data(input); dims=1) + +Base._accumulate!(op, output::AnyJLArray, input::AnyJLArray, dims::Integer, init::Nothing) = + accumulate!(op, typed_data(output), typed_data(input); dims) + +Base._accumulate!(op, output::AnyJLArray, input::AnyJLVector, dims::Nothing, init::Some) = + accumulate!(op, typed_data(output), typed_data(input); dims=1, init=something(init)) + +Base._accumulate!(op, output::AnyJLArray, input::AnyJLArray, dims::Integer, init::Some) = + accumulate!(op, typed_data(output), typed_data(input); dims, init=something(init)) + +Base.accumulate_pairwise!(op, result::AnyJLVector, v::AnyJLVector) = accumulate!(op, result, v) + +# default behavior unless dims are specified by the user +function Base.accumulate(op, A::AnyJLArray; + dims::Union{Nothing,Integer}=nothing, kw...) + nt = values(kw) + if dims === nothing && !(A isa AbstractVector) + # This branch takes care of the cases not handled by `_accumulate!`. + return reshape(accumulate(op, typed_data(A)[:]; kw...), size(A)) + end + if isempty(kw) + out = similar(A, Base.promote_op(op, eltype(A), eltype(A))) + init = AK.neutral_element(op, eltype(out)) + elseif keys(nt) === (:init,) + out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A))) + init = nt.init + else + throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))")) + end + accumulate!(op, typed_data(out), typed_data(A); dims, init) +end + + ## KernelAbstractions interface KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend() diff --git a/src/GPUArrays.jl b/src/GPUArrays.jl index 8c1fc14e..1c991b17 100644 --- a/src/GPUArrays.jl +++ b/src/GPUArrays.jl @@ -16,6 +16,7 @@ using Reexport @reexport using GPUArraysCore using KernelAbstractions +import AcceleratedKernels as AK # device functionality include("device/abstractarray.jl") @@ -27,6 +28,7 @@ include("host/construction.jl") include("host/base.jl") include("host/indexing.jl") include("host/broadcast.jl") +include("host/accumulate.jl") include("host/mapreduce.jl") include("host/linalg.jl") include("host/math.jl") diff --git a/src/host/accumulate.jl b/src/host/accumulate.jl new file mode 100644 index 00000000..b694d690 --- /dev/null +++ b/src/host/accumulate.jl @@ -0,0 +1,35 @@ +## Base interface + +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Nothing) = + AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output))) + +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Nothing) = + AK.accumulate!(op, output, input, get_backend(output); dims, init=AK.neutral_element(op, eltype(output))) + +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUVector, dims::Nothing, init::Some) = + AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init)) + +Base._accumulate!(op, output::AnyGPUArray, input::AnyGPUArray, dims::Integer, init::Some) = + AK.accumulate!(op, output, input, get_backend(output); dims, init=something(init)) + +Base.accumulate_pairwise!(op, result::AnyGPUVector, v::AnyGPUVector) = accumulate!(op, result, v) + +# default behavior unless dims are specified by the user +function Base.accumulate(op, A::AnyGPUArray; + dims::Union{Nothing,Integer}=nothing, kw...) + nt = values(kw) + if dims === nothing && !(A isa AbstractVector) + # This branch takes care of the cases not handled by `_accumulate!`. + return reshape(AK.accumulate(op, A[:], get_backend(A); init = (:init in keys(kw) ? nt.init : AK.neutral_element(op, eltype(A)))), size(A)) + end + if isempty(kw) + out = similar(A, Base.promote_op(op, eltype(A), eltype(A))) + init = AK.neutral_element(op, eltype(out)) + elseif keys(nt) === (:init,) + out = similar(A, Base.promote_op(op, typeof(nt.init), eltype(A))) + init = nt.init + else + throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))")) + end + AK.accumulate!(op, out, A, get_backend(A); dims, init) +end diff --git a/test/Project.toml b/test/Project.toml index e6f21d04..46f44eb3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,3 +12,6 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[sources] +JLArrays = {path = "../lib/JLArrays"} diff --git a/test/testsuite.jl b/test/testsuite.jl index b48d7ccd..cc10b1f2 100644 --- a/test/testsuite.jl +++ b/test/testsuite.jl @@ -90,6 +90,7 @@ include("testsuite/indexing.jl") include("testsuite/base.jl") include("testsuite/vector.jl") include("testsuite/reductions.jl") +include("testsuite/accumulations.jl") include("testsuite/broadcasting.jl") include("testsuite/linalg.jl") include("testsuite/math.jl") diff --git a/test/testsuite/accumulations.jl b/test/testsuite/accumulations.jl new file mode 100644 index 00000000..b4afeac0 --- /dev/null +++ b/test/testsuite/accumulations.jl @@ -0,0 +1,108 @@ +@testsuite "accumulations" (AT, eltypes)->begin + @testset "$ET" for ET in eltypes + range = ET <: Real ? (ET(1):ET(10)) : ET + + # 1d arrays + for num_elems in 1:256 + @test compare(A->accumulate(+, A; init=zero(ET)), AT, rand(range, num_elems)) + end + + for num_elems = rand(1:100, 10) + @test compare(A->accumulate(+, A; init=zero(ET)), AT, rand(range, num_elems)) + end + + for _ in 1:10 # nd arrays reduced as 1d + n1 = rand(1:10) + n2 = rand(1:10) + n3 = rand(1:10) + @test compare(A->accumulate(+, A; init=zero(ET)), AT, rand(range, n1, n2, n3)) + end + + for num_elems = rand(1:100, 10) # init value + init = rand(range) + @test compare(A->accumulate(+, A; init), AT, rand(range, num_elems)) + end + + + # nd arrays + for dims in 1:4 # corner cases + for isize in 1:3 + for jsize in 1:3 + for ksize in 1:3 + @test compare(A->accumulate(+, A; dims, init=zero(ET)), AT, rand(range, isize, jsize, ksize)) + end + end + end + end + + for _ in 1:10 + for dims in 1:3 + n1 = rand(1:10) + n2 = rand(1:10) + n3 = rand(1:10) + @test compare(A->accumulate(+, A; dims, init=zero(ET)), AT, rand(range, n1, n2, n3)) + end + end + + for _ in 1:10 # init value + for dims in 1:3 + n1 = rand(1:10) + n2 = rand(1:10) + n3 = rand(1:10) + init = rand(range) + @test compare(A->accumulate(+, A; init, dims), AT, rand(range, n1, n2, n3)) + end + end + + # Larger containers to try and detect weird bugs + for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not + # Skip large tests on small datatypes + n >= 10000 && sizeof(real(ET)) <= 2 && continue + + @test compare(x->accumulate(+, x), AT, rand(range, n)) + @test compare(x->accumulate(+, x), AT, rand(range, n, 2)) + @test compare(Base.Fix2((x,y)->accumulate(+, x; init=y), rand(range)), AT, rand(range, n)) + end + + # in place + @test compare(x->(accumulate!(+, x, copy(x)); x), AT, rand(range, 2)) + + @test_throws ArgumentError("accumulate does not support the keyword arguments [:bad_kwarg]") accumulate(+, AT(rand(ET, 10)); bad_kwarg="bad") + end +end + +@testsuite "accumulations/cumsum & cumprod" (AT, eltypes)->begin + @test compare(cumsum, AT, rand(Bool, 16)) + + @testset "$ET" for ET in eltypes + range = ET <: Real ? (ET(1):ET(10)) : ET + + # cumsum + for num_elems in rand(1:100, 10) + @test compare(A->cumsum(A; dims=1), AT, rand(range, num_elems)) + end + + for _ in 1:10 + for dims in 1:3 + n1 = rand(1:10) + n2 = rand(1:10) + n3 = rand(1:10) + @test compare(A->cumsum(A; dims), AT, rand(range, n1, n2, n3)) + end + end + + + # cumprod + range = ET <: Real ? (ET(1):ET(10)) : ET + @test compare(A->cumprod(A; dims=1), AT, ones(ET, 100_000)) + + for _ in 1:10 + for dims in 1:3 + n1 = rand(1:10) + n2 = rand(1:10) + n3 = rand(1:10) + @test compare(A->cumprod(A; dims), AT, rand(range, n1, n2, n3)) + end + end + end +end