Merged
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "AcceleratedKernels"
 uuid = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
 authors = ["Andrei-Leonard Nicusan <[email protected]> and contributors"]
-version = "0.3.1"
+version = "0.3.2"
 
 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -23,7 +23,7 @@ AcceleratedKernelsoneAPIExt = "oneAPI"
 [compat]
 ArgCheck = "2"
 GPUArrays = "10, 11"
-KernelAbstractions = "0.9"
+KernelAbstractions = "0.9.34"
 Markdown = "1"
 Metal = "1"
 OhMyThreads = "0.7"
5 changes: 2 additions & 3 deletions prototype/reduce_nd_test.jl
@@ -24,7 +24,6 @@ end
 # Make array with highly unequal per-axis sizes
 s = MtlArray(rand(Int32(1):Int32(100), 10, 1_000_000))
 AK.reduce(+, s, init=zero(eltype(s)))
-ret
 
 # Correctness
 @assert sum_base(s, dims=1) == sum_ak(s, dims=1)
@@ -34,11 +33,11 @@ ret
 # Benchmarks
 println("\nReduction over small axis - AK vs Base")
 display(@benchmark sum_ak($s, dims=1))
-display(@benchmark sum_base($s, dims=1))
+# display(@benchmark sum_base($s, dims=1))
 
 println("\nReduction over long axis - AK vs Base")
 display(@benchmark sum_ak($s, dims=2))
-display(@benchmark sum_base($s, dims=2))
+# display(@benchmark sum_base($s, dims=2))
 
 
 
18 changes: 11 additions & 7 deletions src/accumulate/accumulate_1d.jl
@@ -12,10 +12,11 @@ const ACC_FLAG_P::UInt8 = 1 # Only current block's prefix available
 end
 
 
-@kernel cpu=false inbounds=true function _accumulate_block!(op, v, init, neutral,
-                                                            inclusive,
-                                                            flags, prefixes) # one per block
-
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_block!(
+    op, v, init, neutral,
+    inclusive,
+    flags, prefixes, # one per block
+)
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2
     len = length(v)
@@ -147,7 +148,9 @@ end
 end
 
 
-@kernel cpu=false inbounds=true function _accumulate_previous!(op, v, flags, @Const(prefixes))
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous!(
+    op, v, flags, @Const(prefixes),
+)
 
     len = length(v)
     block_size = @groupsize()[1]
@@ -200,8 +203,9 @@ end
 end
 
 
-@kernel cpu=false inbounds=true function _accumulate_previous_coupled_preblocks!(op, v, prefixes)
-
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous_coupled_preblocks!(
+    op, v, prefixes,
+)
     # No decoupled lookback
     len = length(v)
     block_size = @groupsize()[1]
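Aside (not part of the diff): reconstructed from the kernel names and the flags/prefixes arguments above — the PR itself does not spell this out — the scan appears to run in two phases: each block scans its chunk locally and publishes a block total, and the _accumulate_previous* kernels then fold earlier blocks' totals into later chunks. A serial Julia sketch of that shape, with every name below chosen for illustration only:

let v = collect(1:10), op = +, init = 0, neutral = 0, block_size = 4
    nblocks = cld(length(v), block_size)
    prefixes = zeros(Int, nblocks)
    # Phase 1: scan each block locally from `neutral`; record block totals
    for b in 1:nblocks
        lo, hi = (b - 1) * block_size + 1, min(b * block_size, length(v))
        acc = neutral
        for i in lo:hi
            acc = op(acc, v[i])
            v[i] = acc
        end
        prefixes[b] = acc
    end
    # Phase 2: fold `init` plus all earlier blocks' totals into each block
    running = init
    for b in 1:nblocks
        lo, hi = (b - 1) * block_size + 1, min(b * block_size, length(v))
        for i in lo:hi
            v[i] = op(running, v[i])
        end
        running = op(running, prefixes[b])
    end
    @assert v == accumulate(+, 1:10)
end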
10 changes: 6 additions & 4 deletions src/accumulate/accumulate_nd.jl
@@ -1,5 +1,6 @@
-@kernel inbounds=true cpu=false function _accumulate_nd_by_thread!(v, op, init, dims, inclusive)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _accumulate_nd_by_thread!(
+    v, op, init, dims, inclusive,
+)
     # One thread per outer dimension element, when there are more outer elements than in the
     # reduced dim e.g. accumulate(+, rand(3, 1000), dims=1) => only 3 elements in the accumulated
     # dim
@@ -57,8 +58,9 @@
 end
 
 
-@kernel inbounds=true cpu=false function _accumulate_nd_by_block!(v, op, init, neutral, dims, inclusive)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _accumulate_nd_by_block!(
+    v, op, init, neutral, dims, inclusive,
+)
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2
 
17 changes: 13 additions & 4 deletions src/foreachindex.jl
@@ -1,6 +1,14 @@
-@kernel cpu=false inbounds=true function _forindices_global!(f, indices)
-    i = @index(Global, Linear)
-    f(indices[i])
+@kernel inbounds=true cpu=false unsafe_indices=true function _forindices_global!(f, indices)
+
+    # Calculate global index
+    N = @groupsize()[1]
+    iblock = @index(Group, Linear)
+    ithread = @index(Local, Linear)
+    i = ithread + (iblock - 0x1) * N
+
+    if i <= length(indices)
+        f(indices[i])
+    end
 end
 
 
@@ -13,7 +21,8 @@ function _forindices_gpu(
 )
     # GPU implementation
     @argcheck block_size > 0
-    _forindices_global!(backend, block_size)(f, indices, ndrange=length(indices))
+    blocks = (length(indices) + block_size - 1) ÷ block_size
+    _forindices_global!(backend, block_size)(f, indices, ndrange=(block_size * blocks,))
     nothing
 end
 
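Aside (not part of the diff): as I understand KernelAbstractions' unsafe_indices=true, @index no longer masks out-of-range threads for you, which is why the kernel above recomputes its global index and checks i <= length(indices) itself, and why the launcher rounds the ndrange up to whole blocks. The sizing arithmetic, checked standalone with toy numbers of my choosing:

let block_size = 256, n = 1000                   # n stands in for length(indices)
    blocks = (n + block_size - 1) ÷ block_size   # ceil division, as in the diff
    @assert blocks == cld(n, block_size) == 4
    ndrange = block_size * blocks                # 1024 threads launched
    @assert ndrange - n == 24                    # masked by `i <= length(indices)`
end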
2 changes: 1 addition & 1 deletion src/reduce/mapreduce_1d.jl
@@ -1,4 +1,4 @@
-@kernel inbounds=true cpu=false function _mapreduce_block!(@Const(src), dst, f, op, neutral)
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_block!(@Const(src), dst, f, op, neutral)
 
     @uniform N = @groupsize()[1]
     sdata = @localmem eltype(dst) (N,)
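Aside (not part of the diff): _mapreduce_block!'s body is not expanded here, but the unrolled `if N >= 512u16 ... @synchronize()` ladder visible in mapreduce_nd.jl below is the standard shared-memory tree reduction, and this kernel presumably uses the same ladder. Its loop form, as a runnable serial sketch (names mine):

# Each pass folds the top half of the buffer onto the bottom half and
# halves the active thread count; sdata[1] ends up with the block total.
function block_reduce!(sdata, op)
    N = length(sdata)                  # block size; must be a power of two
    stride = N >> 1
    while stride >= 1
        for ithread in 0:stride-1      # parallel on GPU, then @synchronize()
            sdata[ithread + 1] = op(sdata[ithread + 1], sdata[ithread + stride + 1])
        end
        stride >>= 1                   # plays the role of the N >= 512, 256, ... tests
    end
    sdata[1]
end

@assert block_reduce!(collect(1.0:8.0), +) == 36.0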
171 changes: 91 additions & 80 deletions src/reduce/mapreduce_nd.jl
@@ -1,5 +1,11 @@
-@kernel inbounds=true cpu=false function _mapreduce_nd_by_thread!(@Const(src), dst, f, op, init, dims)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_thread!(
+    @Const(src),
+    dst,
+    f,
+    op,
+    init,
+    dims,
+)
     # One thread per output element, when there are more outer elements than in the reduced dim
     # e.g. reduce(+, rand(3, 1000), dims=1) => only 3 elements in the reduced dim
     src_sizes = size(src)
@@ -64,8 +70,15 @@
 end
 
 
-@kernel inbounds=true cpu=false function _mapreduce_nd_by_block!(@Const(src), dst, f, op, init, neutral, dims)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_block!(
+    @Const(src),
+    dst,
+    f,
+    op,
+    init,
+    neutral,
+    dims,
+)
     # One block per output element, when there are more elements in the reduced dim than in outer
     # e.g. reduce(+, rand(3, 1000), dims=2) => only 3 elements in outer dimensions
     src_sizes = size(src)
@@ -90,86 +103,84 @@ end
     iblock = @index(Group, Linear) - 0x1
     ithread = @index(Local, Linear) - 0x1
 
-    # Each block handles one output element
-    if iblock < output_size
-
-        # # Sometimes slightly faster method using additional memory with
-        # # output_idx = @private typeof(iblock) (ndims,)
-        # tmp = iblock
-        # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
-        #     output_idx[i] = tmp ÷ dst_strides[i]
-        #     tmp = tmp % dst_strides[i]
-        # end
-        # # Compute the base index in src (excluding the reduced axis)
-        # input_base_idx = 0
-        # KernelAbstractions.Extras.@unroll for i in 1:ndims
-        #     i == dims && continue
-        #     input_base_idx += output_idx[i] * src_strides[i]
-        # end
-
-        # Compute the base index in src (excluding the reduced axis)
-        input_base_idx = typeof(ithread)(0)
-        tmp = iblock
-        KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
-            if i != dims
-                input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
-            end
-            tmp = tmp % dst_strides[i]
+    # Each block handles one output element - thus, iblock ∈ [0, output_size)
+
+    # # Sometimes slightly faster method using additional memory with
+    # # output_idx = @private typeof(iblock) (ndims,)
+    # tmp = iblock
+    # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
+    #     output_idx[i] = tmp ÷ dst_strides[i]
+    #     tmp = tmp % dst_strides[i]
+    # end
+    # # Compute the base index in src (excluding the reduced axis)
+    # input_base_idx = 0
+    # KernelAbstractions.Extras.@unroll for i in 1:ndims
+    #     i == dims && continue
+    #     input_base_idx += output_idx[i] * src_strides[i]
+    # end
+
+    # Compute the base index in src (excluding the reduced axis)
+    input_base_idx = typeof(ithread)(0)
+    tmp = iblock
+    KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
+        if i != dims
+            input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
+        end
+        tmp = tmp % dst_strides[i]
+    end
 
-        # We have a block of threads to process the whole reduced dimension. First do pre-reduction
-        # in strides of N
-        partial = neutral
-        i = ithread
-        while i < reduce_size
-            src_idx = input_base_idx + i * src_strides[dims]
-            partial = op(partial, f(src[src_idx + 0x1]))
-            i += N
-        end
+    # We have a block of threads to process the whole reduced dimension. First do pre-reduction
+    # in strides of N
+    partial = neutral
+    i = ithread
+    while i < reduce_size
+        src_idx = input_base_idx + i * src_strides[dims]
+        partial = op(partial, f(src[src_idx + 0x1]))
+        i += N
+    end
 
-        # Store partial result in shared memory; now we are down to a single block to reduce within
-        sdata[ithread + 0x1] = partial
-        @synchronize()
+    # Store partial result in shared memory; now we are down to a single block to reduce within
+    sdata[ithread + 0x1] = partial
+    @synchronize()
 
-        if N >= 512u16
-            ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 256u16
-            ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 128u16
-            ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 64u16
-            ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 32u16
-            ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 16u16
-            ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 8u16
-            ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 4u16
-            ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 2u16
-            ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
-            @synchronize()
-        end
-        if ithread == 0x0
-            dst[iblock + 0x1] = op(init, sdata[0x1])
-        end
+    if N >= 512u16
+        ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 256u16
+        ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 128u16
+        ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 64u16
+        ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 32u16
+        ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 16u16
+        ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 8u16
+        ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 4u16
+        ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 2u16
+        ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
+        @synchronize()
+    end
+    if ithread == 0x0
+        dst[iblock + 0x1] = op(init, sdata[0x1])
+    end
 end
 
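Aside (not part of the diff): the input_base_idx loop above is the heart of the n-dimensional reduction — it peels the linear block index into per-axis output coordinates with ÷ and % over dst_strides, and rebuilds the matching offset into src while skipping the reduced axis. A host-side check with toy sizes (the 3×1000 example comes from the kernel's own comment; the helper name is mine):

function base_index(iblock, dims, src_strides, dst_strides)
    tmp = iblock                       # zero-based linear block id, one per output
    base = 0
    for i in length(dst_strides):-1:1
        if i != dims                   # skip the reduced axis, as in the kernel
            base += (tmp ÷ dst_strides[i]) * src_strides[i]
        end
        tmp = tmp % dst_strides[i]
    end
    base
end

# reduce over dims=2 of a 3×1000 column-major array; dst is 3×1
src_strides = (1, 3)
dst_strides = (1, 3)
@assert base_index(2, 2, src_strides, dst_strides) == 2
# Block 2 (zero-based) then reads src[2 + k*3 + 1] for k = 0:999 —
# exactly row 3 of the source, i.e. the slice reduced into dst[3]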
7 changes: 4 additions & 3 deletions src/sort/merge_sort.jl
@@ -1,4 +1,4 @@
-@kernel inbounds=true function _merge_sort_block!(vec, comp)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_block!(vec, comp)
 
     @uniform N = @groupsize()[1]
     s_buf = @localmem eltype(vec) (N * 0x2,)
@@ -75,8 +75,9 @@
 end
 
 
-@kernel inbounds=true function _merge_sort_global!(@Const(vec_in), vec_out, comp, half_size_group)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_global!(
+    @Const(vec_in), vec_out, comp, half_size_group,
+)
     len = length(vec_in)
     N = @groupsize()[1]
 
10 changes: 6 additions & 4 deletions src/sort/merge_sort_by_key.jl
@@ -1,4 +1,4 @@
-@kernel inbounds=true function _merge_sort_by_key_block!(keys, values, comp)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_by_key_block!(keys, values, comp)
 
     @uniform N = @groupsize()[1]
     s_keys = @localmem eltype(keys) (N * 0x2,)
@@ -97,9 +97,11 @@
 end
 
 
-@kernel inbounds=true function _merge_sort_by_key_global!(@Const(keys_in), keys_out,
-                                                          @Const(values_in), values_out,
-                                                          comp, half_size_group)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_by_key_global!(
+    @Const(keys_in), keys_out,
+    @Const(values_in), values_out,
+    comp, half_size_group,
+)
 
     len = length(keys_in)
     N = @groupsize()[1]
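Aside (not part of the diff): reconstructed from the half_size_group parameter — the PR does not spell this out — the sort looks bottom-up: the *_block! kernels sort runs of one block in shared memory, then the *_global! kernels are relaunched with doubling run sizes until a single sorted run covers the input. The pass count for that schedule, sketched with assumed numbers:

let len = 1_000_000, block_size = 256
    half_size_group = block_size       # runs of block_size sorted in-block first
    passes = 0
    while half_size_group < len
        # each global pass merges adjacent sorted runs of half_size_group
        # elements into runs of twice that length
        half_size_group *= 2
        passes += 1
    end
    @assert passes == 12               # 256 * 2^12 = 1_048_576 ≥ len
end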