diff --git a/src/AcceleratedKernels.jl b/src/AcceleratedKernels.jl index d4655de..d662c2a 100644 --- a/src/AcceleratedKernels.jl +++ b/src/AcceleratedKernels.jl @@ -14,6 +14,7 @@ module AcceleratedKernels using ArgCheck: @argcheck using GPUArraysCore: AnyGPUArray, @allowscalar using KernelAbstractions +using KernelAbstractions: @context import UnsafeAtomics diff --git a/src/reduce/mapreduce_1d_gpu.jl b/src/reduce/mapreduce_1d_gpu.jl index c1e31cc..39e7c41 100644 --- a/src/reduce/mapreduce_1d_gpu.jl +++ b/src/reduce/mapreduce_1d_gpu.jl @@ -25,60 +25,7 @@ @synchronize() - if N >= 512u16 - if ithread < 256u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]) - end - @synchronize() - end - if N >= 256u16 - if ithread < 128u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]) - end - @synchronize() - end - if N >= 128u16 - if ithread < 64u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]) - end - @synchronize() - end - if N >= 64u16 - if ithread < 32u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]) - end - @synchronize() - end - if N >= 32u16 - if ithread < 16u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]) - end - @synchronize() - end - if N >= 16u16 - if ithread < 8u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]) - end - @synchronize() - end - if N >= 8u16 - if ithread < 4u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]) - end - @synchronize() - end - if N >= 4u16 - if ithread < 2u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]) - end - @synchronize() - end - if N >= 2u16 - if ithread < 1u16 - sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 0x1 + 0x1]) - end - @synchronize() - end + @inline reduce_group!(@context, op, sdata, N, ithread) # Code below would work on NVidia GPUs with warp size of 
32, but create race conditions and # return incorrect results on Intel Graphics. It would be useful to have a way to statically diff --git a/src/reduce/mapreduce_nd.jl b/src/reduce/mapreduce_nd.jl index cf7d825..ed815fe 100644 --- a/src/reduce/mapreduce_nd.jl +++ b/src/reduce/mapreduce_nd.jl @@ -332,42 +332,8 @@ end sdata[ithread + 0x1] = partial @synchronize() - if N >= 512u16 - ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])) - @synchronize() - end - if N >= 256u16 - ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])) - @synchronize() - end - if N >= 128u16 - ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])) - @synchronize() - end - if N >= 64u16 - ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])) - @synchronize() - end - if N >= 32u16 - ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])) - @synchronize() - end - if N >= 16u16 - ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])) - @synchronize() - end - if N >= 8u16 - ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])) - @synchronize() - end - if N >= 4u16 - ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])) - @synchronize() - end - if N >= 2u16 - ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1])) - @synchronize() - end + @inline reduce_group!(@context, op, sdata, N, ithread) + if ithread == 0x0 dst[iblock + 0x1] = op(init, sdata[0x1]) end diff --git a/src/reduce/utilities.jl b/src/reduce/utilities.jl index 79a6436..48f387e 100644 --- a/src/reduce/utilities.jl +++ b/src/reduce/utilities.jl @@ -43,3 +43,60 @@ function _mapreduce_nd_apply_init!( dst[i] = op(init, f(src[i])) end end + +@inline 
# Cooperative in-group tree reduction over the first `N` slots of the
# group-shared buffer `sdata`, combining pairs with `op`; the result ends up
# in `sdata[1]`. `ithread` is the zero-based thread index within the group.
# Assumes `N` is a power of two, at most 512 — TODO confirm against callers.
#
# Must be called by every thread of the group: each halving stage ends with a
# group-wide `@synchronize`, and the `N >= …` guards are uniform across the
# group (N is the same for all threads), so no thread skips a barrier.
function reduce_group!(@context, op, sdata, N, ithread)
    # Each stage folds the upper half of the active window onto the lower
    # half; the active window shrinks 512 → 256 → … → 2.
    if N >= 512u16
        ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
        @synchronize()
    end
    if N >= 256u16
        ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
        @synchronize()
    end
    if N >= 128u16
        ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
        @synchronize()
    end
    if N >= 64u16
        ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
        @synchronize()
    end
    if N >= 32u16
        ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
        @synchronize()
    end
    if N >= 16u16
        ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
        @synchronize()
    end
    if N >= 8u16
        ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
        @synchronize()
    end
    if N >= 4u16
        ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
        @synchronize()
    end
    if N >= 2u16
        ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
        @synchronize()
    end
end