Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/AcceleratedKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module AcceleratedKernels
using ArgCheck: @argcheck
using GPUArraysCore: AnyGPUArray, @allowscalar
using KernelAbstractions
using KernelAbstractions: @context
import UnsafeAtomics


Expand Down
55 changes: 1 addition & 54 deletions src/reduce/mapreduce_1d_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,60 +25,7 @@

@synchronize()

if N >= 512u16
if ithread < 256u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1])
end
@synchronize()
end
if N >= 256u16
if ithread < 128u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1])
end
@synchronize()
end
if N >= 128u16
if ithread < 64u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1])
end
@synchronize()
end
if N >= 64u16
if ithread < 32u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1])
end
@synchronize()
end
if N >= 32u16
if ithread < 16u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1])
end
@synchronize()
end
if N >= 16u16
if ithread < 8u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1])
end
@synchronize()
end
if N >= 8u16
if ithread < 4u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1])
end
@synchronize()
end
if N >= 4u16
if ithread < 2u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1])
end
@synchronize()
end
if N >= 2u16
if ithread < 1u16
sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 0x1 + 0x1])
end
@synchronize()
end
@inline reduce_group!(@context, op, sdata, N, ithread)

# Code below would work on NVidia GPUs with warp size of 32, but create race conditions and
# return incorrect results on Intel Graphics. It would be useful to have a way to statically
Expand Down
38 changes: 2 additions & 36 deletions src/reduce/mapreduce_nd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,42 +332,8 @@ end
sdata[ithread + 0x1] = partial
@synchronize()

if N >= 512u16
ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
@synchronize()
end
if N >= 256u16
ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
@synchronize()
end
if N >= 128u16
ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
@synchronize()
end
if N >= 64u16
ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
@synchronize()
end
if N >= 32u16
ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
@synchronize()
end
if N >= 16u16
ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
@synchronize()
end
if N >= 8u16
ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
@synchronize()
end
if N >= 4u16
ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
@synchronize()
end
if N >= 2u16
ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
@synchronize()
end
@inline reduce_group!(@context, op, sdata, N, ithread)

if ithread == 0x0
dst[iblock + 0x1] = op(init, sdata[0x1])
end
Expand Down
57 changes: 57 additions & 0 deletions src/reduce/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,60 @@ function _mapreduce_nd_apply_init!(
dst[i] = op(init, f(src[i]))
end
end

# Pairwise tree reduction of `sdata` using `op`, shared by the 1D and ND mapreduce kernels.
#
# Arguments:
#   @context  kernel context placeholder, supplied by the caller as `@context`; required so
#             that `@synchronize()` can be used inside this helper.
#   op        binary operator used to combine two entries of `sdata`.
#   sdata     1-based indexable buffer holding one partial value per thread; it is updated
#             in place and the combined value ends up in `sdata[0x1]`.
#             (presumably workgroup-local memory - confirm against the calling kernels)
#   N         number of entries to combine (presumably the workgroup size - confirm).
#   ithread   0-based index of the calling thread inside its group.
#
# Each stage halves the active range: when `N >= 2 * stride`, every thread with
# `ithread < stride` folds entry `ithread + stride + 1` into entry `ithread + 1`, after which
# all threads synchronise before the next stage starts. Strides run from 256 down to 1, so
# entries past index 512 are never read here.
# NOTE(review): entries are only fully folded when N is a power of two - looks like callers
# are expected to guarantee that; confirm.
@inline function reduce_group!(@context, op, sdata, N, ithread)
    # stride 256
    if N >= 512u16
        ithread < 256u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 256u16 + 0x1]))
        @synchronize()
    end

    # stride 128
    if N >= 256u16
        ithread < 128u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 128u16 + 0x1]))
        @synchronize()
    end

    # stride 64
    if N >= 128u16
        ithread < 64u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 64u16 + 0x1]))
        @synchronize()
    end

    # stride 32
    if N >= 64u16
        ithread < 32u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 32u16 + 0x1]))
        @synchronize()
    end

    # stride 16
    if N >= 32u16
        ithread < 16u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 16u16 + 0x1]))
        @synchronize()
    end

    # stride 8
    if N >= 16u16
        ithread < 8u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 8u16 + 0x1]))
        @synchronize()
    end

    # stride 4
    if N >= 8u16
        ithread < 4u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 4u16 + 0x1]))
        @synchronize()
    end

    # stride 2
    if N >= 4u16
        ithread < 2u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 2u16 + 0x1]))
        @synchronize()
    end

    # stride 1: only thread 0 is active and leaves the final value in sdata[0x1]
    if N >= 2u16
        ithread < 1u16 && (sdata[ithread + 0x1] = op(sdata[ithread + 0x1], sdata[ithread + 1u16 + 0x1]))
        @synchronize()
    end
end
Loading