diff --git a/Project.toml b/Project.toml
index b51bb27..00c19b0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "AcceleratedKernels"
 uuid = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
 authors = ["Andrei-Leonard Nicusan and contributors"]
-version = "0.3.1"
+version = "0.3.2"
 
 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -23,7 +23,7 @@ AcceleratedKernelsoneAPIExt = "oneAPI"
 [compat]
 ArgCheck = "2"
 GPUArrays = "10, 11"
-KernelAbstractions = "0.9"
+KernelAbstractions = "0.9.34"
 Markdown = "1"
 Metal = "1"
 OhMyThreads = "0.7"
diff --git a/prototype/reduce_nd_test.jl b/prototype/reduce_nd_test.jl
index 9cf151a..b50bd53 100644
--- a/prototype/reduce_nd_test.jl
+++ b/prototype/reduce_nd_test.jl
@@ -24,7 +24,6 @@ end
 # Make array with highly unequal per-axis sizes
 s = MtlArray(rand(Int32(1):Int32(100), 10, 1_000_000))
 AK.reduce(+, s, init=zero(eltype(s)))
-ret
 
 # Correctness
 @assert sum_base(s, dims=1) == sum_ak(s, dims=1)
@@ -34,11 +33,11 @@ ret
 # Benchmarks
 println("\nReduction over small axis - AK vs Base")
 display(@benchmark sum_ak($s, dims=1))
-display(@benchmark sum_base($s, dims=1))
+# display(@benchmark sum_base($s, dims=1))
 
 println("\nReduction over long axis - AK vs Base")
 display(@benchmark sum_ak($s, dims=2))
-display(@benchmark sum_base($s, dims=2))
+# display(@benchmark sum_base($s, dims=2))
diff --git a/src/accumulate/accumulate_1d.jl b/src/accumulate/accumulate_1d.jl
index 327b0c4..c54c244 100644
--- a/src/accumulate/accumulate_1d.jl
+++ b/src/accumulate/accumulate_1d.jl
@@ -12,10 +12,11 @@ const ACC_FLAG_P::UInt8 = 1 # Only current block's prefix available
 end
 
 
-@kernel cpu=false inbounds=true function _accumulate_block!(op, v, init, neutral,
-                                                            inclusive,
-                                                            flags, prefixes) # one per block
-
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_block!(
+    op, v, init, neutral,
+    inclusive,
+    flags, prefixes, # one per block
+)
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2
     len = length(v)
@@ -147,7 +148,9 @@ end
 
 
-@kernel cpu=false inbounds=true function _accumulate_previous!(op, v, flags, @Const(prefixes))
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous!(
+    op, v, flags, @Const(prefixes),
+)
     len = length(v)
     block_size = @groupsize()[1]
@@ -200,8 +203,9 @@ end
 
 
-@kernel cpu=false inbounds=true function _accumulate_previous_coupled_preblocks!(op, v, prefixes)
-
+@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous_coupled_preblocks!(
+    op, v, prefixes,
+)
    # No decoupled lookback
    len = length(v)
    block_size = @groupsize()[1]
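The accumulate kernels above keep two launch-time invariants stated in their NOTE comments: the shared-memory size must be greater than 2 * block_size, and block_size must be a power of two. A minimal host-side check of those invariants, as a sketch only (the helper name is illustrative and not part of the package):

```julia
# Illustrative check of the invariants noted in the accumulate kernels:
# block_size must be a power of 2 and shmem_size must exceed 2 * block_size.
function check_accumulate_launch(block_size::Integer, shmem_size::Integer)
    ispow2(block_size) ||
        throw(ArgumentError("block_size must be a power of 2, got $block_size"))
    shmem_size > 2 * block_size ||
        throw(ArgumentError("shmem_size must be greater than 2 * block_size"))
    nothing
end

check_accumulate_launch(256, 1024)   # passes; check_accumulate_launch(192, 512) would throw
```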
diff --git a/src/accumulate/accumulate_nd.jl b/src/accumulate/accumulate_nd.jl
index 213b948..52d0fa8 100644
--- a/src/accumulate/accumulate_nd.jl
+++ b/src/accumulate/accumulate_nd.jl
@@ -1,5 +1,6 @@
-@kernel inbounds=true cpu=false function _accumulate_nd_by_thread!(v, op, init, dims, inclusive)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _accumulate_nd_by_thread!(
+    v, op, init, dims, inclusive,
+)
     # One thread per outer dimension element, when there are more outer elements than in the
     # reduced dim e.g. accumulate(+, rand(3, 1000), dims=1) => only 3 elements in the accumulated
     # dim
@@ -57,8 +58,9 @@ end
 
 
-@kernel inbounds=true cpu=false function _accumulate_nd_by_block!(v, op, init, neutral, dims, inclusive)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _accumulate_nd_by_block!(
+    v, op, init, neutral, dims, inclusive,
+)
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2
diff --git a/src/foreachindex.jl b/src/foreachindex.jl
index c970aa7..afb2027 100644
--- a/src/foreachindex.jl
+++ b/src/foreachindex.jl
@@ -1,6 +1,14 @@
-@kernel cpu=false inbounds=true function _forindices_global!(f, indices)
-    i = @index(Global, Linear)
-    f(indices[i])
+@kernel inbounds=true cpu=false unsafe_indices=true function _forindices_global!(f, indices)
+
+    # Calculate global index
+    N = @groupsize()[1]
+    iblock = @index(Group, Linear)
+    ithread = @index(Local, Linear)
+    i = ithread + (iblock - 0x1) * N
+
+    if i <= length(indices)
+        f(indices[i])
+    end
 end
 
 
@@ -13,7 +21,8 @@ function _forindices_gpu(
 )
     # GPU implementation
     @argcheck block_size > 0
-    _forindices_global!(backend, block_size)(f, indices, ndrange=length(indices))
+    blocks = (length(indices) + block_size - 1) ÷ block_size
+    _forindices_global!(backend, block_size)(f, indices, ndrange=(block_size * blocks,))
     nothing
 end
diff --git a/src/reduce/mapreduce_1d.jl b/src/reduce/mapreduce_1d.jl
index ce623e9..dd43e55 100644
--- a/src/reduce/mapreduce_1d.jl
+++ b/src/reduce/mapreduce_1d.jl
@@ -1,4 +1,4 @@
-@kernel inbounds=true cpu=false function _mapreduce_block!(@Const(src), dst, f, op, neutral)
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_block!(@Const(src), dst, f, op, neutral)
 
     @uniform N = @groupsize()[1]
     sdata = @localmem eltype(dst) (N,)
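The foreachindex.jl change above is the template repeated across this diff: with `unsafe_indices=true`, the kernels stop relying on `@index(Global, Linear)` and instead rebuild the global index from the group and local indices, while the launcher rounds `ndrange` up to a whole number of blocks and the kernel guards the tail explicitly. A self-contained sketch of that pattern, assuming KernelAbstractions 0.9.34+ and a GPU array; the `_scale!`/`scale!` names are illustrative, not package API:

```julia
using KernelAbstractions

@kernel cpu=false inbounds=true unsafe_indices=true function _scale!(y, a, @Const(x))
    # Rebuild the global linear index from the group and local indices
    N = @groupsize()[1]
    iblock = @index(Group, Linear)
    ithread = @index(Local, Linear)
    i = ithread + (iblock - 0x1) * N

    # ndrange was padded up to a multiple of the block size, so guard the tail
    if i <= length(y)
        y[i] = a * x[i]
    end
end

function scale!(y, a, x; block_size::Int=256)
    backend = get_backend(y)
    blocks = (length(y) + block_size - 1) ÷ block_size
    _scale!(backend, block_size)(y, a, x, ndrange=(block_size * blocks,))
    KernelAbstractions.synchronize(backend)
    y
end
```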
diff --git a/src/reduce/mapreduce_nd.jl b/src/reduce/mapreduce_nd.jl
index c5cf4e2..0a162d4 100644
--- a/src/reduce/mapreduce_nd.jl
+++ b/src/reduce/mapreduce_nd.jl
@@ -1,5 +1,11 @@
-@kernel inbounds=true cpu=false function _mapreduce_nd_by_thread!(@Const(src), dst, f, op, init, dims)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_thread!(
+    @Const(src),
+    dst,
+    f,
+    op,
+    init,
+    dims,
+)
     # One thread per output element, when there are more outer elements than in the reduced dim
     # e.g. reduce(+, rand(3, 1000), dims=1) => only 3 elements in the reduced dim
     src_sizes = size(src)
@@ -64,8 +70,15 @@ end
 
 
-@kernel inbounds=true cpu=false function _mapreduce_nd_by_block!(@Const(src), dst, f, op, init, neutral, dims)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _mapreduce_nd_by_block!(
+    @Const(src),
+    dst,
+    f,
+    op,
+    init,
+    neutral,
+    dims,
+)
     # One block per output element, when there are more elements in the reduced dim than in outer
     # e.g. reduce(+, rand(3, 1000), dims=2) => only 3 elements in outer dimensions
     src_sizes = size(src)
@@ -90,86 +103,84 @@ end
 
     iblock = @index(Group, Linear) - 0x1
     ithread = @index(Local, Linear) - 0x1
 
-    # Each block handles one output element
-    if iblock < output_size
-
-        # # Sometimes slightly faster method using additional memory with
-        # # output_idx = @private typeof(iblock) (ndims,)
-        # tmp = iblock
-        # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
-        #     output_idx[i] = tmp ÷ dst_strides[i]
-        #     tmp = tmp % dst_strides[i]
-        # end
-        # # Compute the base index in src (excluding the reduced axis)
-        # input_base_idx = 0
-        # KernelAbstractions.Extras.@unroll for i in 1:ndims
-        #     i == dims && continue
-        #     input_base_idx += output_idx[i] * src_strides[i]
-        # end
-
-        # Compute the base index in src (excluding the reduced axis)
-        input_base_idx = typeof(ithread)(0)
-        tmp = iblock
-        KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
-            if i != dims
-                input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
-            end
-            tmp = tmp % dst_strides[i]
+    # Each block handles one output element - thus, iblock ∈ [0, output_size)
+
+    # # Sometimes slightly faster method using additional memory with
+    # # output_idx = @private typeof(iblock) (ndims,)
+    # tmp = iblock
+    # KernelAbstractions.Extras.@unroll for i in ndims:-1:1
+    #     output_idx[i] = tmp ÷ dst_strides[i]
+    #     tmp = tmp % dst_strides[i]
+    # end
+    # # Compute the base index in src (excluding the reduced axis)
+    # input_base_idx = 0
+    # KernelAbstractions.Extras.@unroll for i in 1:ndims
+    #     i == dims && continue
+    #     input_base_idx += output_idx[i] * src_strides[i]
+    # end
+
+    # Compute the base index in src (excluding the reduced axis)
+    input_base_idx = typeof(ithread)(0)
+    tmp = iblock
+    KernelAbstractions.Extras.@unroll for i in ndims:-1i16:1i16
+        if i != dims
+            input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
         end
+        tmp = tmp % dst_strides[i]
+    end
 
-        # We have a block of threads to process the whole reduced dimension. First do pre-reduction
-        # in strides of N
-        partial = neutral
-        i = ithread
-        while i < reduce_size
-            src_idx = input_base_idx + i * src_strides[dims]
-            partial = op(partial, f(src[src_idx + 0x1]))
-            i += N
-        end
+    # We have a block of threads to process the whole reduced dimension. First do pre-reduction
+    # in strides of N
+    partial = neutral
+    i = ithread
+    while i < reduce_size
+        src_idx = input_base_idx + i * src_strides[dims]
+        partial = op(partial, f(src[src_idx + 0x1]))
+        i += N
+    end
 
-        # Store partial result in shared memory; now we are down to a single block to reduce within
-        sdata[ithread + 0x1] = partial
-        @synchronize()
+    # Store partial result in shared memory; now we are down to a single block to reduce within
+    sdata[ithread + 0x1] = partial
+    @synchronize()
 
-        if N >= 512u16
-            ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 256u16
-            ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 128u16
-            ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 64u16
-            ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 32u16
-            ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 16u16
-            ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 8u16
-            ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 4u16
-            ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
-            @synchronize()
-        end
-        if N >= 2u16
-            ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
-            @synchronize()
-        end
-        if ithread == 0x0
-            dst[iblock + 0x1] = op(init, sdata[0x1])
-        end
+    if N >= 512u16
+        ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 256u16
+        ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 128u16
+        ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 64u16
+        ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 32u16
+        ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 16u16
+        ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 8u16
+        ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 4u16
+        ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
+        @synchronize()
+    end
+    if N >= 2u16
+        ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
+        @synchronize()
+    end
+    if ithread == 0x0
+        dst[iblock + 0x1] = op(init, sdata[0x1])
     end
 end
diff --git a/src/sort/merge_sort.jl b/src/sort/merge_sort.jl
index 0dda224..8079f75 100644
--- a/src/sort/merge_sort.jl
+++ b/src/sort/merge_sort.jl
@@ -1,4 +1,4 @@
-@kernel inbounds=true function _merge_sort_block!(vec, comp)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_block!(vec, comp)
 
     @uniform N = @groupsize()[1]
     s_buf = @localmem eltype(vec) (N * 0x2,)
@@ -75,8 +75,9 @@ end
 
 
-@kernel inbounds=true function _merge_sort_global!(@Const(vec_in), vec_out, comp, half_size_group)
-
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_global!(
+    @Const(vec_in), vec_out, comp, half_size_group,
+)
     len = length(vec_in)
     N = @groupsize()[1]
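The base-index arithmetic in `_mapreduce_nd_by_block!` above decomposes the flat (0-based) block index into per-axis output coordinates by dividing by the destination strides, largest stride first, and accumulates the matching source offset while skipping the reduced axis. A host-side restatement of that arithmetic, illustrative only (the helper name and explicit stride tuples are not part of the package):

```julia
# 0-based restatement of the kernel's stride walk: given the flat index of an output
# element, return the offset of the first source element that reduces into it.
function base_offset(iblock::Int, dst_strides, src_strides, dims::Int)
    input_base_idx = 0
    tmp = iblock
    for i in length(dst_strides):-1:1        # largest stride first
        if i != dims
            input_base_idx += (tmp ÷ dst_strides[i]) * src_strides[i]
        end
        tmp = tmp % dst_strides[i]
    end
    input_base_idx
end

# reduce(+, rand(3, 1000), dims=2): dst is 3×1, so both stride tuples are (1, 3)
base_offset(2, (1, 3), (1, 3), 2)   # == 2, the first source element of the third row
```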
diff --git a/src/sort/merge_sort_by_key.jl b/src/sort/merge_sort_by_key.jl
index c3c8746..1afb680 100644
--- a/src/sort/merge_sort_by_key.jl
+++ b/src/sort/merge_sort_by_key.jl
@@ -1,4 +1,4 @@
-@kernel inbounds=true function _merge_sort_by_key_block!(keys, values, comp)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_by_key_block!(keys, values, comp)
 
     @uniform N = @groupsize()[1]
     s_keys = @localmem eltype(keys) (N * 0x2,)
@@ -97,9 +97,11 @@ end
 
 
-@kernel inbounds=true function _merge_sort_by_key_global!(@Const(keys_in), keys_out,
-                                                          @Const(values_in), values_out,
-                                                          comp, half_size_group)
+@kernel inbounds=true cpu=false unsafe_indices=true function _merge_sort_by_key_global!(
+    @Const(keys_in), keys_out,
+    @Const(values_in), values_out,
+    comp, half_size_group,
+)
 
     len = length(keys_in)
     N = @groupsize()[1]
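None of these kernel-level changes alter the public entry points; callers still pass the logical index range and the padded launch stays internal. A usage sketch with Metal, as in the prototype script, assuming the exported `foreachindex` wrapper keeps forwarding a `block_size` keyword to the GPU launcher:

```julia
import Metal                         # any KernelAbstractions GPU backend works the same way
import AcceleratedKernels as AK

v = Metal.MtlArray(rand(Float32, 100_000))

# The manual bounds check inside the kernel makes the padded ndrange invisible here
AK.foreachindex(v, block_size=256) do i
    v[i] = 2f0 * v[i] + 1f0
end
```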