Skip to content

Commit ff4097f

Browse files
committed
Simplify algo selection
1 parent e1a110f commit ff4097f

File tree

4 files changed

+91
-105
lines changed

4 files changed

+91
-105
lines changed

docs/src/api.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
@uniform
1414
@groupsize
1515
@ndrange
16-
synchronize
17-
allocate
16+
@groupreduce
1817
```
1918

2019
## Host language
2120

2221
```@docs
22+
synchronize
23+
allocate
2324
KernelAbstractions.zeros
2425
```
2526

src/reduce.jl

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
1+
export @groupreduce, Reduction
2+
3+
module Reduction
4+
const thread = Val(:thread)
5+
const warp = Val(:warp)
6+
end
7+
18
"""
2-
@groupreduce algo op val neutral [groupsize]
9+
@groupreduce op val neutral algo [groupsize]
310
411
Perform group reduction of `val` using `op`.
512
613
# Arguments
714
815
- `algo` specifies which reduction algorithm to use:
9-
- `:thread`:
16+
- `Reduction.thread`:
1017
Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
1118
Available across all backends.
12-
- `:warp`:
19+
- `Reduction.warp`:
1320
Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
21+
Potentially faster, since it requires fewer writes to shared memory.
22+
To query if backend supports warp reduction, use `supports_warp_reduction(backend)`.
1423
1524
- `neutral` should be a neutral element w.r.t. `op`, such that `op(neutral, x) == x`.
25+
1626
- `groupsize` specifies size of the workgroup.
1727
If a kernel does not specify `groupsize` statically, then it is required to
1828
provide `groupsize`.
@@ -23,69 +33,53 @@ Perform group reduction of `val` using `op`.
2333
2434
Result of the reduction.
2535
"""
26-
macro groupreduce(algo, op, val, neutral)
27-
f = if algo.value == :thread
28-
__groupreduce
29-
elseif algo.value == :warp
30-
__warp_groupreduce
31-
else
32-
error(
33-
"@groupreduce supports only :thread or :warp as a reduction algorithm, " *
34-
"but $(algo.value) was specified.")
35-
end
36+
macro groupreduce(op, val, neutral, algo)
3637
quote
37-
$f(
38+
__groupreduce(
3839
$(esc(:__ctx__)),
3940
$(esc(op)),
4041
$(esc(val)),
4142
$(esc(neutral)),
4243
Val(prod($groupsize($(esc(:__ctx__))))),
44+
$(esc(algo)),
4345
)
4446
end
4547
end
4648

47-
macro groupreduce(algo, op, val, neutral, groupsize)
48-
f = if algo.value == :thread
49-
__groupreduce
50-
elseif algo.value == :warp
51-
__warp_groupreduce
52-
else
53-
error(
54-
"@groupreduce supports only :thread or :warp as a reduction algorithm, " *
55-
"but $(algo.value) was specified.")
56-
end
49+
macro groupreduce(op, val, neutral, algo, groupsize)
5750
quote
58-
$f(
51+
__groupreduce(
5952
$(esc(:__ctx__)),
6053
$(esc(op)),
6154
$(esc(val)),
6255
$(esc(neutral)),
6356
Val($(esc(groupsize))),
57+
$(esc(algo)),
6458
)
6559
end
6660
end
6761

68-
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
62+
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize}
6963
storage = @localmem T groupsize
7064

7165
local_idx = @index(Local)
72-
local_idx ≤ groupsize && (storage[local_idx] = val)
66+
@inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
7367
@synchronize()
7468

7569
s::UInt64 = groupsize ÷ 0x2
7670
while s > 0x0
7771
if (local_idx - 0x1) < s
7872
other_idx = local_idx + s
7973
if other_idx ≤ groupsize
80-
storage[local_idx] = op(storage[local_idx], storage[other_idx])
74+
@inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
8175
end
8276
end
8377
@synchronize()
8478
s >>= 0x1
8579
end
8680

8781
if local_idx == 0x1
88-
val = storage[local_idx]
82+
@inbounds val = storage[local_idx]
8983
end
9084
return val
9185
end
@@ -98,8 +92,9 @@ macro shfl_down(val, offset)
9892
end
9993
end
10094

101-
# Backends should implement this.
95+
# Backends should implement these two.
10296
function __shfl_down end
97+
supports_warp_reduction(::CPU) = false
10398

10499
@inline function __warp_reduce(val, op)
105100
offset::UInt32 = UInt32(32) ÷ 0x2
@@ -115,7 +110,7 @@ const __warpsize::UInt32 = 32
115110
# Maximum number of warps (for a groupsize = 1024).
116111
const __warp_bins::UInt32 = 32
117112

118-
function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) where {T, groupsize}
113+
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize}
119114
storage = @localmem T __warp_bins
120115

121116
local_idx = @index(Local)
@@ -124,12 +119,12 @@ function __warp_groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}) w
124119

125120
# Each warp performs a reduction and writes results into its own bin in `storage`.
126121
val = __warp_reduce(val, op)
127-
lane == 0x1 && (storage[warp_id] = val)
122+
@inbounds lane == 0x1 && (storage[warp_id] = val)
128123
@synchronize()
129124

130125
# Final reduction of the `storage` on the first warp.
131126
within_storage = (local_idx - 0x1) < groupsize ÷ __warpsize
132-
val = within_storage ? storage[lane] : neutral
127+
@inbounds val = within_storage ? storage[lane] : neutral
133128
warp_id == 0x1 && (val = __warp_reduce(val, op))
134129
return val
135130
end

test/groupreduce.jl

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,48 @@
1-
@kernel function groupreduce_thread_1!(y, x, op, neutral)
1+
@kernel function groupreduce_1!(y, x, op, neutral, algo)
22
i = @index(Global)
33
val = i > length(x) ? neutral : x[i]
4-
res = KernelAbstractions.@groupreduce(:thread, op, val, neutral)
4+
res = @groupreduce(op, val, neutral, algo)
55
i == 1 && (y[1] = res)
66
end
77

8-
@kernel function groupreduce_thread_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
8+
@kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
99
i = @index(Global)
1010
val = i > length(x) ? neutral : x[i]
11-
res = KernelAbstractions.@groupreduce(:thread, op, val, neutral, groupsize)
12-
i == 1 && (y[1] = res)
13-
end
14-
15-
@kernel function groupreduce_warp_1!(y, x, op, neutral)
16-
i = @index(Global)
17-
val = i > length(x) ? neutral : x[i]
18-
res = KernelAbstractions.@groupreduce(:warp, op, val, neutral)
19-
i == 1 && (y[1] = res)
20-
end
21-
22-
@kernel function groupreduce_warp_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize}
23-
i = @index(Global)
24-
val = i > length(x) ? neutral : x[i]
25-
res = KernelAbstractions.@groupreduce(:warp, op, val, neutral, groupsize)
11+
res = @groupreduce(op, val, neutral, algo, groupsize)
2612
i == 1 && (y[1] = res)
2713
end
2814

2915
function groupreduce_testsuite(backend, AT)
3016
@testset "@groupreduce" begin
31-
@testset ":thread T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
17+
@testset "thread reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
3218
x = AT(ones(T, n))
3319
y = AT(zeros(T, 1))
3420

35-
groupreduce_thread_1!(backend(), n)(y, x, +, zero(T); ndrange=n)
21+
groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.thread; ndrange=n)
3622
@test Array(y)[1] == n
3723

38-
groupreduce_thread_2!(backend())(y, x, +, zero(T), Val(128); ndrange=n)
24+
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(128); ndrange=n)
3925
@test Array(y)[1] == 128
4026

41-
groupreduce_thread_2!(backend())(y, x, +, zero(T), Val(64); ndrange=n)
27+
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(64); ndrange=n)
4228
@test Array(y)[1] == 64
4329
end
4430

45-
@testset ":warp T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
46-
x = AT(ones(T, n))
47-
y = AT(zeros(T, 1))
48-
groupreduce_warp_1!(backend(), n)(y, x, +, zero(T); ndrange=n)
49-
@test Array(y)[1] == n
31+
warp_reduction = KernelAbstractions.supports_warp_reduction(backend())
32+
if warp_reduction
33+
@testset "warp reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
5034

51-
groupreduce_warp_2!(backend())(y, x, +, zero(T), Val(128); ndrange=n)
52-
@test Array(y)[1] == 128
35+
x = AT(ones(T, n))
36+
y = AT(zeros(T, 1))
37+
groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.warp; ndrange=n)
38+
@test Array(y)[1] == n
5339

54-
groupreduce_warp_2!(backend())(y, x, +, zero(T), Val(64); ndrange=n)
55-
@test Array(y)[1] == 64
40+
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(128); ndrange=n)
41+
@test Array(y)[1] == 128
42+
43+
groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(64); ndrange=n)
44+
@test Array(y)[1] == 64
45+
end
5646
end
5747
end
5848
end

test/testsuite.jl

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -41,57 +41,57 @@ include("specialfunctions.jl")
4141
include("groupreduce.jl")
4242

4343
function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
44-
# @conditional_testset "Unittests" skip_tests begin
45-
# unittest_testsuite(backend, backend_str, backend_mod, DAT; skip_tests)
46-
# end
44+
@conditional_testset "Unittests" skip_tests begin
45+
unittest_testsuite(backend, backend_str, backend_mod, DAT; skip_tests)
46+
end
4747

48-
# @conditional_testset "SpecialFunctions" skip_tests begin
49-
# specialfunctions_testsuite(backend)
50-
# end
48+
@conditional_testset "SpecialFunctions" skip_tests begin
49+
specialfunctions_testsuite(backend)
50+
end
5151

52-
# @conditional_testset "Localmem" skip_tests begin
53-
# localmem_testsuite(backend, AT)
54-
# end
52+
@conditional_testset "Localmem" skip_tests begin
53+
localmem_testsuite(backend, AT)
54+
end
5555

56-
# @conditional_testset "Private" skip_tests begin
57-
# private_testsuite(backend, AT)
58-
# end
56+
@conditional_testset "Private" skip_tests begin
57+
private_testsuite(backend, AT)
58+
end
5959

60-
# @conditional_testset "Unroll" skip_tests begin
61-
# unroll_testsuite(backend, AT)
62-
# end
60+
@conditional_testset "Unroll" skip_tests begin
61+
unroll_testsuite(backend, AT)
62+
end
6363

64-
# @testset "NDIteration" begin
65-
# nditeration_testsuite()
66-
# end
64+
@testset "NDIteration" begin
65+
nditeration_testsuite()
66+
end
6767

68-
# @conditional_testset "copyto!" skip_tests begin
69-
# copyto_testsuite(backend, AT)
70-
# end
68+
@conditional_testset "copyto!" skip_tests begin
69+
copyto_testsuite(backend, AT)
70+
end
7171

72-
# @conditional_testset "Devices" skip_tests begin
73-
# devices_testsuite(backend)
74-
# end
72+
@conditional_testset "Devices" skip_tests begin
73+
devices_testsuite(backend)
74+
end
7575

76-
# @conditional_testset "Printing" skip_tests begin
77-
# printing_testsuite(backend)
78-
# end
76+
@conditional_testset "Printing" skip_tests begin
77+
printing_testsuite(backend)
78+
end
7979

80-
# @conditional_testset "Compiler" skip_tests begin
81-
# compiler_testsuite(backend, AT)
82-
# end
80+
@conditional_testset "Compiler" skip_tests begin
81+
compiler_testsuite(backend, AT)
82+
end
8383

84-
# @conditional_testset "Reflection" skip_tests begin
85-
# reflection_testsuite(backend, backend_str, AT)
86-
# end
84+
@conditional_testset "Reflection" skip_tests begin
85+
reflection_testsuite(backend, backend_str, AT)
86+
end
8787

88-
# @conditional_testset "Convert" skip_tests begin
89-
# convert_testsuite(backend, AT)
90-
# end
88+
@conditional_testset "Convert" skip_tests begin
89+
convert_testsuite(backend, AT)
90+
end
9191

92-
# @conditional_testset "Examples" skip_tests begin
93-
# examples_testsuite(backend_str)
94-
# end
92+
@conditional_testset "Examples" skip_tests begin
93+
examples_testsuite(backend_str)
94+
end
9595

9696
@conditional_testset "@groupreduce" skip_tests begin
9797
groupreduce_testsuite(backend, AT)

0 commit comments

Comments
 (0)