Skip to content

Commit cbc8bd5

Browse files
committed
Auto-select reduction algorithm & remove at-shfl_down macro
1 parent a647992 commit cbc8bd5

File tree

2 files changed

+55
-63
lines changed

2 files changed

+55
-63
lines changed

src/reduce.jl

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
1-
export @groupreduce, Reduction
1+
export @groupreduce
22

33
module Reduction
44
const thread = Val(:thread)
55
const warp = Val(:warp)
66
end
77

88
"""
9-
@groupreduce op val neutral algo [groupsize]
9+
@groupreduce op val neutral [groupsize]
1010
1111
Perform group reduction of `val` using `op`.
12+
If backend supports warp reduction, it will use it instead of thread reduction.
1213
1314
# Arguments
1415
15-
- `algo` specifies which reduction algorithm to use:
16-
- `Reduction.thread`:
17-
Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
18-
Available accross all backends.
19-
- `Reduction.warp`:
20-
Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
21-
Potentially faster, since requires fewer writes to shared memory.
22-
To query if backend supports warp reduction, use `supports_warp_reduction(backend)`.
23-
2416
- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`.
2517
2618
- `groupsize` specifies size of the workgroup.
@@ -33,29 +25,51 @@ Perform group reduction of `val` using `op`.
3325
3426
Result of the reduction.
3527
"""
36-
macro groupreduce(op, val, neutral, algo)
28+
macro groupreduce(op, val, neutral)
3729
return quote
38-
__groupreduce(
39-
$(esc(:__ctx__)),
40-
$(esc(op)),
41-
$(esc(val)),
42-
$(esc(neutral)),
43-
Val(prod($groupsize($(esc(:__ctx__))))),
44-
$(esc(algo)),
45-
)
30+
if __supports_warp_reduction()
31+
__groupreduce(
32+
$(esc(:__ctx__)),
33+
$(esc(op)),
34+
$(esc(val)),
35+
$(esc(neutral)),
36+
Val(prod($groupsize($(esc(:__ctx__))))),
37+
$(esc(Reduction.warp)),
38+
)
39+
else
40+
__groupreduce(
41+
$(esc(:__ctx__)),
42+
$(esc(op)),
43+
$(esc(val)),
44+
$(esc(neutral)),
45+
Val(prod($groupsize($(esc(:__ctx__))))),
46+
$(esc(Reduction.thread)),
47+
)
48+
end
4649
end
4750
end
4851

49-
macro groupreduce(op, val, neutral, algo, groupsize)
52+
macro groupreduce(op, val, neutral, groupsize)
5053
return quote
51-
__groupreduce(
52-
$(esc(:__ctx__)),
53-
$(esc(op)),
54-
$(esc(val)),
55-
$(esc(neutral)),
56-
Val($(esc(groupsize))),
57-
$(esc(algo)),
58-
)
54+
if __supports_warp_reduction()
55+
__groupreduce(
56+
$(esc(:__ctx__)),
57+
$(esc(op)),
58+
$(esc(val)),
59+
$(esc(neutral)),
60+
Val($(esc(groupsize))),
61+
$(esc(Reduction.warp)),
62+
)
63+
else
64+
__groupreduce(
65+
$(esc(:__ctx__)),
66+
$(esc(op)),
67+
$(esc(val)),
68+
$(esc(neutral)),
69+
Val($(esc(groupsize))),
70+
$(esc(Reduction.thread)),
71+
)
72+
end
5973
end
6074
end
6175

@@ -86,15 +100,9 @@ end
86100

87101
# Warp groupreduce.
88102

89-
macro shfl_down(val, offset)
90-
return quote
91-
$__shfl_down($(esc(val)), $(esc(offset)))
92-
end
93-
end
94-
95-
# Backends should implement these two.
103+
# NOTE: Backends should implement these two device functions (with `@device_override`).
96104
function __shfl_down end
97-
supports_warp_reduction(::Backend) = false
105+
function __supports_warp_reduction() end
98106

99107
# Assume warp is 32 lanes.
100108
const __warpsize = UInt32(32)
@@ -104,7 +112,7 @@ const __warp_bins = UInt32(32)
104112
@inline function __warp_reduce(val, op)
105113
offset::UInt32 = __warpsize ÷ 0x02
106114
while offset > 0x00
107-
val = op(val, @shfl_down(val, offset))
115+
val = op(val, __shfl_down(val, offset))
108116
offset >>= 0x01
109117
end
110118
return val

test/groupreduce.jl

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,35 @@
1-
@kernel function groupreduce_1!(y, x, op, neutral, algo)
1+
@kernel cpu=false function groupreduce_1!(y, x, op, neutral)
22
i = @index(Global)
33
val = i > length(x) ? neutral : x[i]
4-
res = @groupreduce(op, val, neutral, algo)
4+
res = @groupreduce(op, val, neutral)
55
i == 1 && (y[1] = res)
66
end
77

8-
@kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
8+
@kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
99
i = @index(Global)
1010
val = i > length(x) ? neutral : x[i]
11-
res = @groupreduce(op, val, neutral, algo, groupsize)
11+
res = @groupreduce(op, val, neutral, groupsize)
1212
i == 1 && (y[1] = res)
1313
end
1414

1515
function groupreduce_testsuite(backend, AT)
16-
# TODO should be better way of querying max groupsize
16+
# TODO should be a better way of querying max groupsize
1717
groupsizes = "$backend" == "oneAPIBackend" ?
1818
(256,) :
1919
(256, 512, 1024)
2020
@testset "@groupreduce" begin
21-
@testset "thread reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in groupsizes
21+
@testset "T=$T, n=$n" for T in (Float16, Float32, Float64, Int16, Int32, Int64), n in groupsizes
2222
x = AT(ones(T, n))
2323
y = AT(zeros(T, 1))
2424

25-
groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.thread; ndrange = n)
25+
groupreduce_1!(backend(), n)(y, x, +, zero(T); ndrange = n)
2626
@test Array(y)[1] == n
2727

28-
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(128); ndrange = n)
28+
groupreduce_2!(backend())(y, x, +, zero(T), Val(128); ndrange = n)
2929
@test Array(y)[1] == 128
3030

31-
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(64); ndrange = n)
31+
groupreduce_2!(backend())(y, x, +, zero(T), Val(64); ndrange = n)
3232
@test Array(y)[1] == 64
3333
end
34-
35-
warp_reduction = KernelAbstractions.supports_warp_reduction(backend())
36-
if warp_reduction
37-
@testset "warp reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in groupsizes
38-
x = AT(ones(T, n))
39-
y = AT(zeros(T, 1))
40-
groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.warp; ndrange = n)
41-
@test Array(y)[1] == n
42-
43-
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(128); ndrange = n)
44-
@test Array(y)[1] == 128
45-
46-
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(64); ndrange = n)
47-
@test Array(y)[1] == 64
48-
end
49-
end
5034
end
5135
end

0 commit comments

Comments
 (0)