Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 0 additions & 69 deletions ext/AcceleratedKernelsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,73 +38,4 @@ function AK.accumulate!(
)
end


# Metal-specific method of `AK.cumsum`: an inclusive `+` scan delegated to
# `AK.accumulate`, with `AK.ScanPrefixes()` as the default algorithm.
# `max_tasks` and `min_elems` are accepted in the signature but are not forwarded.
function AK.cumsum(
    src::AbstractArray, backend::MetalBackend;
    init=zero(eltype(src)), neutral=zero(eltype(src)),
    dims::Union{Nothing, Int}=nothing,
    # CPU settings - not used
    max_tasks::Int=Threads.nthreads(), min_elems::Int=1,
    # Algorithm choice
    alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    temp_flags::Union{Nothing, AbstractArray}=nothing,
)
    return AK.accumulate(
        +, src, backend;
        init, neutral, dims,
        inclusive=true,
        alg,
        block_size, temp, temp_flags,
    )
end


# Metal-specific method of `AK.cumprod`: an inclusive `*` scan delegated to
# `AK.accumulate`, with `AK.ScanPrefixes()` as the default algorithm.
# `max_tasks` and `min_elems` are accepted in the signature but are not forwarded.
function AK.cumprod(
    src::AbstractArray, backend::MetalBackend;
    init=one(eltype(src)), neutral=one(eltype(src)),
    dims::Union{Nothing, Int}=nothing,
    # CPU settings - not used
    max_tasks::Int=Threads.nthreads(), min_elems::Int=1,
    # Algorithm choice
    alg::AK.AccumulateAlgorithm=AK.ScanPrefixes(),
    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    temp_flags::Union{Nothing, AbstractArray}=nothing,
)
    return AK.accumulate(
        *, src, backend;
        init, neutral, dims,
        inclusive=true,
        alg,
        block_size, temp, temp_flags,
    )
end


end # module AcceleratedKernelsMetalExt
28 changes: 6 additions & 22 deletions ext/AcceleratedKernelsoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,12 @@ function AK.any(

# Algorithm choice
alg::AK.PredicatesAlgorithm=AK.MapReduce(),

# CPU settings
max_tasks=Threads.nthreads(),
min_elems=1,

# GPU settings
block_size::Int=256,
kwargs...
)
AK._any_impl(
pred, v, backend;
alg=alg,
max_tasks=max_tasks,
min_elems=min_elems,
block_size=block_size,
alg,
kwargs...
)
end

Expand All @@ -35,20 +27,12 @@ function AK.all(

# Algorithm choice
alg::AK.PredicatesAlgorithm=AK.MapReduce(),

# CPU settings
max_tasks=Threads.nthreads(),
min_elems=1,

# GPU settings
block_size::Int=256,
kwargs...
)
AK._all_impl(
pred, v, backend;
alg=alg,
max_tasks=max_tasks,
min_elems=min_elems,
block_size=block_size,
alg,
kwargs...
)
end

Expand Down
3 changes: 1 addition & 2 deletions src/accumulate/accumulate_1d_cpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ function accumulate_1d!(
if itask == 1
_accumulate_1d_cpu_section!(
op, @view(v[irange]);
init=init,
inclusive=inclusive,
init, inclusive,
)
else
# Later sections should always be inclusively accumulated
Expand Down
174 changes: 24 additions & 150 deletions src/arithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,12 @@ s = AK.sum(m, dims=2, temp=temp)
function sum(
    src::AbstractArray, backend::Backend=get_backend(src);
    init=zero(eltype(src)),
    dims::Union{Nothing, Int}=nothing,

    # CPU settings
    max_tasks=Threads.nthreads(),
    min_elems=1,

    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    switch_below::Int=0,
    kwargs...
)
    # Thin wrapper: reduce `src` with `+`. Each keyword is forwarded exactly once;
    # naming `init` twice in one call (`init=init` plus the `init` shorthand) is
    # rejected by Julia as a repeated keyword argument. Any extra keywords collected
    # in `kwargs` are passed through unchanged.
    return reduce(
        +, src, backend;
        init,
        dims,

        max_tasks,
        min_elems,

        block_size,
        temp,
        switch_below,
        kwargs...
    )
end

Expand Down Expand Up @@ -116,28 +100,12 @@ p = AK.prod(m, dims=2, temp=temp)
function prod(
    src::AbstractArray, backend::Backend=get_backend(src);
    init=one(eltype(src)),
    dims::Union{Nothing, Int}=nothing,

    # CPU settings
    max_tasks=Threads.nthreads(),
    min_elems=1,

    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    switch_below::Int=0,
    kwargs...
)
    # Thin wrapper: reduce `src` with `*`. Each keyword is forwarded exactly once;
    # naming `init` twice in one call (`init=init` plus the `init` shorthand) is
    # rejected by Julia as a repeated keyword argument. Any extra keywords collected
    # in `kwargs` are passed through unchanged.
    return reduce(
        *, src, backend;
        init,
        dims,

        max_tasks,
        min_elems,

        block_size,
        temp,
        switch_below,
        kwargs...
    )
end

Expand Down Expand Up @@ -188,28 +156,12 @@ m = AK.maximum(m, dims=2, temp=temp)
function maximum(
    src::AbstractArray, backend::Backend=get_backend(src);
    init=typemin(eltype(src)),
    dims::Union{Nothing, Int}=nothing,

    # CPU settings
    max_tasks=Threads.nthreads(),
    min_elems=1,

    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    switch_below::Int=0,
    kwargs...
)
    # Thin wrapper: reduce `src` with `max`, starting from `init`. Each keyword is
    # forwarded exactly once; naming `init` twice in one call (`init=init` plus the
    # `init` shorthand) is rejected by Julia as a repeated keyword argument. Any
    # extra keywords collected in `kwargs` are passed through unchanged.
    return reduce(
        max, src, backend;
        init,
        dims,

        max_tasks,
        min_elems,

        block_size,
        temp,
        switch_below,
        kwargs...
    )
end

Expand Down Expand Up @@ -260,28 +212,12 @@ m = AK.minimum(m, dims=2, temp=temp)
function minimum(
    src::AbstractArray, backend::Backend=get_backend(src);
    init=typemax(eltype(src)),
    dims::Union{Nothing, Int}=nothing,

    # CPU settings
    max_tasks=Threads.nthreads(),
    min_elems=1,

    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    switch_below::Int=0,
    kwargs...
)
    # Thin wrapper: reduce `src` with `min`, starting from `init`. Each keyword is
    # forwarded exactly once; naming `init` twice in one call (`init=init` plus the
    # `init` shorthand) is rejected by Julia as a repeated keyword argument. Any
    # extra keywords collected in `kwargs` are passed through unchanged.
    return reduce(
        min, src, backend;
        init,
        dims,

        max_tasks,
        min_elems,

        block_size,
        temp,
        switch_below,
        kwargs...
    )
end

Expand Down Expand Up @@ -338,59 +274,27 @@ c = AK.count(m, dims=2, temp=temp)
function count(
    src::AbstractArray, backend::Backend=get_backend(src);
    init=0,
    dims::Union{Nothing, Int}=nothing,

    # CPU settings
    max_tasks=Threads.nthreads(),
    min_elems=1,

    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    switch_below::Int=0,
    kwargs...
)
    # Each element is mapped to one/zero of `typeof(init)` and the results are added
    # via `mapreduce`, so the counter type follows `init`. `typeof(init)` is kept
    # inline in the closure rather than bound to a local.
    #
    # Each keyword is forwarded exactly once; naming `init` twice in one call
    # (`init=init` plus the `init` shorthand) is rejected by Julia as a repeated
    # keyword argument.
    return mapreduce(
        x -> x ? one(typeof(init)) : zero(typeof(init)), +, src, backend;
        init,
        neutral=zero(typeof(init)),
        dims,

        max_tasks,
        min_elems,

        block_size,
        temp,
        switch_below,
        kwargs...
    )
end


function count(
    f, src::AbstractArray, backend::Backend=get_backend(src);
    init=0,
    dims::Union{Nothing, Int}=nothing,

    # CPU settings
    max_tasks=Threads.nthreads(),
    min_elems=1,

    # GPU settings
    block_size::Int=256,
    temp::Union{Nothing, AbstractArray}=nothing,
    switch_below::Int=0,
    kwargs...
)
    # Each element is mapped to one/zero of `typeof(init)` depending on `f(x)` and
    # the results are added via `mapreduce`, so the counter type follows `init`.
    # `typeof(init)` is kept inline in the closure rather than bound to a local.
    #
    # Each keyword is forwarded exactly once; naming `init` twice in one call
    # (`init=init` plus the `init` shorthand) is rejected by Julia as a repeated
    # keyword argument.
    return mapreduce(
        x -> f(x) ? one(typeof(init)) : zero(typeof(init)), +, src, backend;
        init,
        neutral=zero(typeof(init)),
        dims,

        max_tasks,
        min_elems,

        block_size,
        temp,
        switch_below,
        kwargs...
    )
end

Expand Down Expand Up @@ -437,28 +341,13 @@ function cumsum(
src::AbstractArray, backend::Backend=get_backend(src);
init=zero(eltype(src)),
neutral=zero(eltype(src)),
dims::Union{Nothing, Int}=nothing,

# Algorithm choice
alg::AccumulateAlgorithm=DecoupledLookback(),

# GPU settings
block_size::Int=256,
temp::Union{Nothing, AbstractArray}=nothing,
temp_flags::Union{Nothing, AbstractArray}=nothing,
kwargs...
)
accumulate(
+, src, backend;
init=init,
neutral=neutral,
dims=dims,
init, neutral,
inclusive=true,

alg=alg,

block_size=block_size,
temp=temp,
temp_flags=temp_flags,
kwargs...
)
end

Expand Down Expand Up @@ -505,27 +394,12 @@ function cumprod(
src::AbstractArray, backend::Backend=get_backend(src);
init=one(eltype(src)),
neutral=one(eltype(src)),
dims::Union{Nothing, Int}=nothing,

# Algorithm choice
alg::AccumulateAlgorithm=DecoupledLookback(),

# GPU settings
block_size::Int=256,
temp::Union{Nothing, AbstractArray}=nothing,
temp_flags::Union{Nothing, AbstractArray}=nothing,
kwargs...
)
accumulate(
*, src, backend;
init=init,
neutral=neutral,
dims=dims,
init, neutral,
inclusive=true,

alg=alg,

block_size=block_size,
temp=temp,
temp_flags=temp_flags,
kwargs...
)
end
Loading
Loading