Skip to content

Commit e88c427

Browse files
committed
simplify accumulate-family definitions
Also remove redundant `kwarg=kwarg` definitions for `accumulate`
1 parent 4959366 commit e88c427

File tree

2 files changed

+26
-84
lines changed

2 files changed

+26
-84
lines changed

ext/AcceleratedKernelsMetalExt.jl

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,14 @@ import AcceleratedKernels as AK
1010
function AK.accumulate!(
1111
op, v::AbstractArray, backend::MetalBackend;
1212
init,
13-
neutral=AK.neutral_element(op, eltype(v)),
14-
dims::Union{Nothing, Int}=nothing,
15-
inclusive::Bool=true,
16-
17-
# CPU settings - not used
18-
max_tasks::Int=Threads.nthreads(),
19-
min_elems::Int=1,
20-
21-
# Algorithm choice
13+
# Algorithm choice is the only differing default
2214
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
23-
24-
# GPU settings
25-
block_size::Int=256,
26-
temp::Union{Nothing, AbstractArray}=nothing,
27-
temp_flags::Union{Nothing, AbstractArray}=nothing,
15+
kwargs...
2816
)
2917
AK._accumulate_impl!(
30-
op, v, backend,
31-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
32-
alg=alg,
33-
block_size=block_size, temp=temp, temp_flags=temp_flags,
18+
op, v, backend;
19+
init, alg,
20+
kwargs...
3421
)
3522
end
3623

@@ -39,28 +26,15 @@ end
3926
function AK.accumulate!(
4027
op, dst::AbstractArray, src::AbstractArray, backend::MetalBackend;
4128
init,
42-
neutral=AK.neutral_element(op, eltype(dst)),
43-
dims::Union{Nothing, Int}=nothing,
44-
inclusive::Bool=true,
45-
46-
# CPU settings - not used
47-
max_tasks::Int=Threads.nthreads(),
48-
min_elems::Int=1,
49-
50-
# Algorithm choice
29+
# Algorithm choice is the only differing default
5130
alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
52-
53-
# GPU settings
54-
block_size::Int=256,
55-
temp::Union{Nothing, AbstractArray}=nothing,
56-
temp_flags::Union{Nothing, AbstractArray}=nothing,
31+
kwargs...
5732
)
5833
copyto!(dst, src)
5934
AK._accumulate_impl!(
60-
op, dst, backend,
61-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
62-
alg=alg,
63-
block_size=block_size, temp=temp, temp_flags=temp_flags,
35+
op, dst, backend;
36+
init, alg,
37+
kwargs...
6438
)
6539
end
6640

src/accumulate/accumulate.jl

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -124,58 +124,26 @@ AK.accumulate!(+, v, alg=AK.ScanPrefixes())
124124
function accumulate!(
125125
op, v::AbstractArray, backend::Backend=get_backend(v);
126126
init,
127-
neutral=neutral_element(op, eltype(v)),
128-
dims::Union{Nothing, Int}=nothing,
129-
inclusive::Bool=true,
130-
131-
# CPU settings
132-
max_tasks::Int=Threads.nthreads(),
133-
min_elems::Int=2,
134-
135-
# Algorithm choice
136-
alg::AccumulateAlgorithm=DecoupledLookback(),
137-
138-
# GPU settings
139-
block_size::Int=256,
140-
temp::Union{Nothing, AbstractArray}=nothing,
141-
temp_flags::Union{Nothing, AbstractArray}=nothing,
127+
kwargs...
142128
)
143129
_accumulate_impl!(
144-
op, v, backend,
145-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
146-
max_tasks=max_tasks, min_elems=min_elems,
147-
alg=alg,
148-
block_size=block_size, temp=temp, temp_flags=temp_flags,
130+
op, v, backend;
131+
init,
132+
kwargs...
149133
)
150134
end
151135

152136

153137
function accumulate!(
154138
op, dst::AbstractArray, src::AbstractArray, backend::Backend=get_backend(dst);
155139
init,
156-
neutral=neutral_element(op, eltype(dst)),
157-
dims::Union{Nothing, Int}=nothing,
158-
inclusive::Bool=true,
159-
160-
# CPU settings
161-
max_tasks::Int=Threads.nthreads(),
162-
min_elems::Int=2,
163-
164-
# Algorithm choice
165-
alg::AccumulateAlgorithm=DecoupledLookback(),
166-
167-
# GPU settings
168-
block_size::Int=256,
169-
temp::Union{Nothing, AbstractArray}=nothing,
170-
temp_flags::Union{Nothing, AbstractArray}=nothing,
140+
kwargs...
171141
)
172142
copyto!(dst, src)
173143
_accumulate_impl!(
174-
op, dst, backend,
175-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
176-
max_tasks=max_tasks, min_elems=min_elems,
177-
alg=alg,
178-
block_size=block_size, temp=temp, temp_flags=temp_flags,
144+
op, dst, backend;
145+
init,
146+
kwargs...
179147
)
180148
end
181149

@@ -200,17 +168,17 @@ function _accumulate_impl!(
200168
)
201169
if isnothing(dims)
202170
return accumulate_1d!(
203-
op, v, backend, alg,
204-
init=init, neutral=neutral, inclusive=inclusive,
205-
max_tasks=max_tasks, min_elems=min_elems,
206-
block_size=block_size, temp=temp, temp_flags=temp_flags,
171+
op, v, backend, alg;
172+
init, neutral, inclusive,
173+
max_tasks, min_elems,
174+
block_size, temp, temp_flags,
207175
)
208176
else
209177
return accumulate_nd!(
210-
op, v, backend,
211-
init=init, neutral=neutral, dims=dims, inclusive=inclusive,
212-
max_tasks=max_tasks, min_elems=min_elems,
213-
block_size=block_size,
178+
op, v, backend;
179+
init, neutral, dims, inclusive,
180+
max_tasks, min_elems,
181+
block_size,
214182
)
215183
end
216184
end

0 commit comments

Comments
 (0)