Skip to content

Commit e1a110f

Browse files
committed
Implement groupreduce API
1 parent 8a87f77 commit e1a110f

File tree

4 files changed

+238
-38
lines changed

4 files changed

+238
-38
lines changed

src/KernelAbstractions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,8 @@ function __fake_compiler_job end
798798
# - LoopInfo
799799
###
800800

801+
include("reduce.jl")
802+
801803
include("extras/extras.jl")
802804

803805
include("reflection.jl")

src/reduce.jl

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
@groupreduce algo op val neutral [groupsize]
3+
4+
Perform group reduction of `val` using `op`.
5+
6+
# Arguments
7+
8+
- `algo` specifies which reduction algorithm to use:
9+
- `:thread`:
10+
Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
11+
Available accross all backends.
12+
- `:warp`:
13+
Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
14+
15+
- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`.
16+
- `groupsize` specifies size of the workgroup.
17+
If a kernel does not specifies `groupsize` statically, then it is required to
18+
provide `groupsize`.
19+
Also can be used to perform reduction accross first `groupsize` threads
20+
(if `groupsize < @groupsize()`).
21+
22+
# Returns
23+
24+
Result of the reduction.
25+
"""
26+
macro groupreduce(algo, op, val, neutral)
27+
f = if algo.value == :thread
28+
__groupreduce
29+
elseif algo.value == :warp
30+
__warp_groupreduce
31+
else
32+
error(
33+
"@groupreduce supports only :thread or :warp as a reduction algorithm, " *
34+
"but $(algo.value) was specified.")
35+
end
36+
quote
37+
$f(
38+
$(esc(:__ctx__)),
39+
$(esc(op)),
40+
$(esc(val)),
41+
$(esc(neutral)),
42+
Val(prod($groupsize($(esc(:__ctx__))))),
43+
)
44+
end
45+
end
46+
47+
macro groupreduce(algo, op, val, neutral, groupsize)
48+
f = if algo.value == :thread
49+
__groupreduce
50+
elseif algo.value == :warp
51+
__warp_groupreduce
52+
else
53+
error(
54+
"@groupreduce supports only :thread or :warp as a reduction algorithm, " *
55+
"but $(algo.value) was specified.")
56+
end
57+
quote
58+
$f(
59+
$(esc(:__ctx__)),
60+
$(esc(op)),
61+
$(esc(val)),
62+
$(esc(neutral)),
63+
Val($(esc(groupsize))),
64+
)
65+
end
66+
end
67+
68+
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
69+
storage = @localmem T groupsize
70+
71+
local_idx = @index(Local)
72+
local_idx groupsize && (storage[local_idx] = val)
73+
@synchronize()
74+
75+
s::UInt64 = groupsize ÷ 0x2
76+
while s > 0x0
77+
if (local_idx - 0x1) < s
78+
other_idx = local_idx + s
79+
if other_idx groupsize
80+
storage[local_idx] = op(storage[local_idx], storage[other_idx])
81+
end
82+
end
83+
@synchronize()
84+
s >>= 0x1
85+
end
86+
87+
if local_idx == 0x1
88+
val = storage[local_idx]
89+
end
90+
return val
91+
end
92+
93+
# Warp groupreduce.
94+
95+
macro shfl_down(val, offset)
96+
quote
97+
$__shfl_down($(esc(val)), $(esc(offset)))
98+
end
99+
end
100+
101+
# Backends should implement this.
102+
function __shfl_down end
103+
104+
@inline function __warp_reduce(val, op)
105+
offset::UInt32 = UInt32(32) ÷ 0x2
106+
while offset > 0x0
107+
val = op(val, @shfl_down(val, offset))
108+
offset >>= 0x1
109+
end
110+
return val
111+
end
112+
113+
# Assume warp is 32 lanes.
114+
const __warpsize::UInt32 = 32
115+
# Maximum number of warps (for a groupsize = 1024).
116+
const __warp_bins::UInt32 = 32
117+
118+
function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
119+
storage = @localmem T __warp_bins
120+
121+
local_idx = @index(Local)
122+
lane = (local_idx - 0x1) % __warpsize + 0x1
123+
warp_id = (local_idx - 0x1) ÷ __warpsize + 0x1
124+
125+
# Each warp performs a reduction and writes results into its own bin in `storage`.
126+
val = __warp_reduce(val, op)
127+
lane == 0x1 && (storage[warp_id] = val)
128+
@synchronize()
129+
130+
# Final reduction of the `storage` on the first warp.
131+
within_storage = (local_idx - 0x1) < groupsize ÷ __warpsize
132+
val = within_storage ? storage[lane] : neutral
133+
warp_id == 0x1 && (val = __warp_reduce(val, op))
134+
return val
135+
end

test/groupreduce.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
@kernel function groupreduce_thread_1!(y, x, op, neutral)
2+
i = @index(Global)
3+
val = i > length(x) ? neutral : x[i]
4+
res = KernelAbstractions.@groupreduce(:thread, op, val, neutral)
5+
i == 1 && (y[1] = res)
6+
end
7+
8+
@kernel function groupreduce_thread_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
9+
i = @index(Global)
10+
val = i > length(x) ? neutral : x[i]
11+
res = KernelAbstractions.@groupreduce(:thread, op, val, neutral, groupsize)
12+
i == 1 && (y[1] = res)
13+
end
14+
15+
@kernel function groupreduce_warp_1!(y, x, op, neutral)
16+
i = @index(Global)
17+
val = i > length(x) ? neutral : x[i]
18+
res = KernelAbstractions.@groupreduce(:warp, op, val, neutral)
19+
i == 1 && (y[1] = res)
20+
end
21+
22+
@kernel function groupreduce_warp_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
23+
i = @index(Global)
24+
val = i > length(x) ? neutral : x[i]
25+
res = KernelAbstractions.@groupreduce(:warp, op, val, neutral, groupsize)
26+
i == 1 && (y[1] = res)
27+
end
28+
29+
function groupreduce_testsuite(backend, AT)
30+
@testset "@groupreduce" begin
31+
@testset ":thread T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
32+
x = AT(ones(T, n))
33+
y = AT(zeros(T, 1))
34+
35+
groupreduce_thread_1!(backend(), n)(y, x, +, zero(T); ndrange=n)
36+
@test Array(y)[1] == n
37+
38+
groupreduce_thread_2!(backend())(y, x, +, zero(T), Val(128); ndrange=n)
39+
@test Array(y)[1] == 128
40+
41+
groupreduce_thread_2!(backend())(y, x, +, zero(T), Val(64); ndrange=n)
42+
@test Array(y)[1] == 64
43+
end
44+
45+
@testset ":warp T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
46+
x = AT(ones(T, n))
47+
y = AT(zeros(T, 1))
48+
groupreduce_warp_1!(backend(), n)(y, x, +, zero(T); ndrange=n)
49+
@test Array(y)[1] == n
50+
51+
groupreduce_warp_2!(backend())(y, x, +, zero(T), Val(128); ndrange=n)
52+
@test Array(y)[1] == 128
53+
54+
groupreduce_warp_2!(backend())(y, x, +, zero(T), Val(64); ndrange=n)
55+
@test Array(y)[1] == 64
56+
end
57+
end
58+
end

test/testsuite.jl

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,58 +38,63 @@ include("reflection.jl")
3838
include("examples.jl")
3939
include("convert.jl")
4040
include("specialfunctions.jl")
41+
include("groupreduce.jl")
4142

4243
function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
43-
@conditional_testset "Unittests" skip_tests begin
44-
unittest_testsuite(backend, backend_str, backend_mod, DAT; skip_tests)
45-
end
44+
# @conditional_testset "Unittests" skip_tests begin
45+
# unittest_testsuite(backend, backend_str, backend_mod, DAT; skip_tests)
46+
# end
4647

47-
@conditional_testset "SpecialFunctions" skip_tests begin
48-
specialfunctions_testsuite(backend)
49-
end
48+
# @conditional_testset "SpecialFunctions" skip_tests begin
49+
# specialfunctions_testsuite(backend)
50+
# end
5051

51-
@conditional_testset "Localmem" skip_tests begin
52-
localmem_testsuite(backend, AT)
53-
end
52+
# @conditional_testset "Localmem" skip_tests begin
53+
# localmem_testsuite(backend, AT)
54+
# end
5455

55-
@conditional_testset "Private" skip_tests begin
56-
private_testsuite(backend, AT)
57-
end
56+
# @conditional_testset "Private" skip_tests begin
57+
# private_testsuite(backend, AT)
58+
# end
5859

59-
@conditional_testset "Unroll" skip_tests begin
60-
unroll_testsuite(backend, AT)
61-
end
60+
# @conditional_testset "Unroll" skip_tests begin
61+
# unroll_testsuite(backend, AT)
62+
# end
6263

63-
@testset "NDIteration" begin
64-
nditeration_testsuite()
65-
end
64+
# @testset "NDIteration" begin
65+
# nditeration_testsuite()
66+
# end
6667

67-
@conditional_testset "copyto!" skip_tests begin
68-
copyto_testsuite(backend, AT)
69-
end
68+
# @conditional_testset "copyto!" skip_tests begin
69+
# copyto_testsuite(backend, AT)
70+
# end
7071

71-
@conditional_testset "Devices" skip_tests begin
72-
devices_testsuite(backend)
73-
end
72+
# @conditional_testset "Devices" skip_tests begin
73+
# devices_testsuite(backend)
74+
# end
7475

75-
@conditional_testset "Printing" skip_tests begin
76-
printing_testsuite(backend)
77-
end
76+
# @conditional_testset "Printing" skip_tests begin
77+
# printing_testsuite(backend)
78+
# end
7879

79-
@conditional_testset "Compiler" skip_tests begin
80-
compiler_testsuite(backend, AT)
81-
end
80+
# @conditional_testset "Compiler" skip_tests begin
81+
# compiler_testsuite(backend, AT)
82+
# end
8283

83-
@conditional_testset "Reflection" skip_tests begin
84-
reflection_testsuite(backend, backend_str, AT)
85-
end
84+
# @conditional_testset "Reflection" skip_tests begin
85+
# reflection_testsuite(backend, backend_str, AT)
86+
# end
8687

87-
@conditional_testset "Convert" skip_tests begin
88-
convert_testsuite(backend, AT)
89-
end
88+
# @conditional_testset "Convert" skip_tests begin
89+
# convert_testsuite(backend, AT)
90+
# end
91+
92+
# @conditional_testset "Examples" skip_tests begin
93+
# examples_testsuite(backend_str)
94+
# end
9095

91-
@conditional_testset "Examples" skip_tests begin
92-
examples_testsuite(backend_str)
96+
@conditional_testset "@groupreduce" skip_tests begin
97+
groupreduce_testsuite(backend, AT)
9398
end
9499

95100
return

0 commit comments

Comments
 (0)