diff --git a/ext/AcceleratedKernelsMetalExt.jl b/ext/AcceleratedKernelsMetalExt.jl index 2f10894..326429a 100644 --- a/ext/AcceleratedKernelsMetalExt.jl +++ b/ext/AcceleratedKernelsMetalExt.jl @@ -10,27 +10,14 @@ import AcceleratedKernels as AK function AK.accumulate!( op, v::AbstractArray, backend::MetalBackend; init, - neutral=AK.neutral_element(op, eltype(v)), - dims::Union{Nothing, Int}=nothing, - inclusive::Bool=true, - - # CPU settings - not used - max_tasks::Int=Threads.nthreads(), - min_elems::Int=1, - - # Algorithm choice + # Algorithm choice is the only differing default alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(), - - # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + kwargs... ) AK._accumulate_impl!( - op, v, backend, - init=init, neutral=neutral, dims=dims, inclusive=inclusive, - alg=alg, - block_size=block_size, temp=temp, temp_flags=temp_flags, + op, v, backend; + init, alg, + kwargs... ) end @@ -39,28 +26,15 @@ end function AK.accumulate!( op, dst::AbstractArray, src::AbstractArray, backend::MetalBackend; init, - neutral=AK.neutral_element(op, eltype(dst)), - dims::Union{Nothing, Int}=nothing, - inclusive::Bool=true, - - # CPU settings - not used - max_tasks::Int=Threads.nthreads(), - min_elems::Int=1, - - # Algorithm choice + # Algorithm choice is the only differing default alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(), - - # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + kwargs... ) copyto!(dst, src) AK._accumulate_impl!( - op, dst, backend, - init=init, neutral=neutral, dims=dims, inclusive=inclusive, - alg=alg, - block_size=block_size, temp=temp, temp_flags=temp_flags, + op, dst, backend; + init, alg, + kwargs... ) end diff --git a/src/accumulate/accumulate.jl b/src/accumulate/accumulate.jl index 1738a37..f87b21e 100644 --- a/src/accumulate/accumulate.jl +++ b/src/accumulate/accumulate.jl @@ -124,28 +124,12 @@ AK.accumulate!(+, v, alg=AK.ScanPrefixes()) function accumulate!( op, v::AbstractArray, backend::Backend=get_backend(v); init, - neutral=neutral_element(op, eltype(v)), - dims::Union{Nothing, Int}=nothing, - inclusive::Bool=true, - - # CPU settings - max_tasks::Int=Threads.nthreads(), - min_elems::Int=2, - - # Algorithm choice - alg::AccumulateAlgorithm=DecoupledLookback(), - - # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + kwargs... ) _accumulate_impl!( - op, v, backend, - init=init, neutral=neutral, dims=dims, inclusive=inclusive, - max_tasks=max_tasks, min_elems=min_elems, - alg=alg, - block_size=block_size, temp=temp, temp_flags=temp_flags, + op, v, backend; + init, + kwargs... ) end @@ -153,29 +137,13 @@ end function accumulate!( op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(dst); init, - neutral=neutral_element(op, eltype(dst)), - dims::Union{Nothing, Int}=nothing, - inclusive::Bool=true, - - # CPU settings - max_tasks::Int=Threads.nthreads(), - min_elems::Int=2, - - # Algorithm choice - alg::AccumulateAlgorithm=DecoupledLookback(), - - # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + kwargs... ) copyto!(dst, src) _accumulate_impl!( - op, dst, backend, - init=init, neutral=neutral, dims=dims, inclusive=inclusive, - max_tasks=max_tasks, min_elems=min_elems, - alg=alg, - block_size=block_size, temp=temp, temp_flags=temp_flags, + op, dst, backend; + init, + kwargs... ) end @@ -200,17 +168,17 @@ function _accumulate_impl!( ) if isnothing(dims) return accumulate_1d!( - op, v, backend, alg, - init=init, neutral=neutral, inclusive=inclusive, - max_tasks=max_tasks, min_elems=min_elems, - block_size=block_size, temp=temp, temp_flags=temp_flags, + op, v, backend, alg; + init, neutral, inclusive, + max_tasks, min_elems, + block_size, temp, temp_flags, ) else return accumulate_nd!( - op, v, backend, - init=init, neutral=neutral, dims=dims, inclusive=inclusive, - max_tasks=max_tasks, min_elems=min_elems, - block_size=block_size, + op, v, backend; + init, neutral, dims, inclusive, + max_tasks, min_elems, + block_size, ) end end @@ -242,31 +210,15 @@ Out-of-place version of [`accumulate!`](@ref). function accumulate( op, v::AbstractArray, backend::Backend=get_backend(v); init, - neutral=neutral_element(op, eltype(v)), - dims::Union{Nothing, Int}=nothing, - inclusive::Bool=true, - - # CPU settings - max_tasks::Int=Threads.nthreads(), - min_elems::Int=2, - - # Algorithm choice - alg::AccumulateAlgorithm=DecoupledLookback(), - - # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + kwargs... ) dst_type = Base.promote_op(op, eltype(v), typeof(init)) vcopy = similar(v, dst_type) copyto!(vcopy, v) accumulate!( op, vcopy, backend; - init=init, neutral=neutral, dims=dims, inclusive=inclusive, - max_tasks=max_tasks, min_elems=min_elems, - alg=alg, - block_size=block_size, temp=temp, temp_flags=temp_flags, + init, + kwargs... ) vcopy end diff --git a/src/accumulate/accumulate_1d_cpu.jl b/src/accumulate/accumulate_1d_cpu.jl index 2066697..eda07ae 100644 --- a/src/accumulate/accumulate_1d_cpu.jl +++ b/src/accumulate/accumulate_1d_cpu.jl @@ -2,16 +2,16 @@ function accumulate_1d!( op, v::AbstractArray, backend::CPU, alg; init, neutral, - inclusive::Bool=true, + inclusive::Bool, # CPU settings - max_tasks::Int=Threads.nthreads(), - min_elems::Int=2, + max_tasks::Int, + min_elems::Int, # GPU settings - not used - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + block_size::Int, + temp::Union{Nothing, AbstractArray}, + temp_flags::Union{Nothing, AbstractArray}, ) # Trivial case if length(v) == 0 diff --git a/src/accumulate/accumulate_1d_gpu.jl b/src/accumulate/accumulate_1d_gpu.jl index b66f443..be3ee59 100644 --- a/src/accumulate/accumulate_1d_gpu.jl +++ b/src/accumulate/accumulate_1d_gpu.jl @@ -252,16 +252,16 @@ function accumulate_1d!( op, v::AbstractArray, backend::GPU, ::DecoupledLookback; init, neutral, - inclusive::Bool=true, + inclusive::Bool, # CPU settings - not used - max_tasks::Int=Threads.nthreads(), - min_elems::Int=1, + max_tasks::Int, + min_elems::Int, # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + block_size::Int, + temp::Union{Nothing, AbstractArray}, + temp_flags::Union{Nothing, AbstractArray}, ) # Correctness checks @argcheck block_size > 0 @@ -311,16 +311,16 @@ function accumulate_1d!( op, v::AbstractArray, backend::GPU, ::ScanPrefixes; init, neutral, - inclusive::Bool=true, + inclusive::Bool, # CPU settings - not used - max_tasks::Int=Threads.nthreads(), - min_elems::Int=1, + max_tasks::Int, + min_elems::Int, # GPU settings - block_size::Int=256, - temp::Union{Nothing, AbstractArray}=nothing, - temp_flags::Union{Nothing, AbstractArray}=nothing, + block_size::Int, + temp::Union{Nothing, AbstractArray}, + temp_flags::Union{Nothing, AbstractArray}, ) # Correctness checks @argcheck block_size > 0 diff --git a/src/accumulate/accumulate_nd.jl b/src/accumulate/accumulate_nd.jl index 219f7b0..8aaa83e 100644 --- a/src/accumulate/accumulate_nd.jl +++ b/src/accumulate/accumulate_nd.jl @@ -1,16 +1,16 @@ function accumulate_nd!( op, v::AbstractArray, backend::Backend; init, - neutral=neutral_element(op, eltype(v)), + neutral, dims::Int, - inclusive::Bool=true, + inclusive::Bool, # CPU settings - max_tasks::Int=Threads.nthreads(), - min_elems::Int=1, + max_tasks::Int, + min_elems::Int, # GPU settings - block_size::Int=256, + block_size::Int, ) # Correctness checks @argcheck block_size > 0 diff --git a/test/accumulate.jl b/test/accumulate.jl index a8b425e..b50ba47 100644 --- a/test/accumulate.jl +++ b/test/accumulate.jl @@ -81,6 +81,9 @@ AK.accumulate!(+, y; init=Int32(init), inclusive=false) @test all(Array(y) .== 10:19) + # Test that undefined kwargs are not accepted + @test_throws MethodError AK.accumulate(+, y; init=10, dims=2, inclusive=false, bad=:kwarg) + # Testing different settings AK.accumulate!(+, array_from_host(ones(Int32, 1000)), init=0, inclusive=false, block_size=128, @@ -186,6 +189,9 @@ end sh = Array(s) @test all([sh[i, :] == 10:19 for i in 1:10]) + # Test that undefined kwargs are not accepted + @test_throws MethodError AK.accumulate(+, v; init=10, dims=2, inclusive=false, bad=:kwarg) + # Testing different settings AK.accumulate( (x, y) -> x + 1,