Skip to content

Implement groupreduce API #559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
@uniform
@groupsize
@ndrange
synchronize
allocate
@groupreduce
```

## Host language

```@docs
synchronize
allocate
KernelAbstractions.zeros
```

Expand Down
2 changes: 2 additions & 0 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,8 @@ function __fake_compiler_job end
# - LoopInfo
###

include("reduce.jl")

include("extras/extras.jl")

include("reflection.jl")
Expand Down
130 changes: 130 additions & 0 deletions src/reduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
export @groupreduce, Reduction

# Namespace of algorithm selectors for `@groupreduce`; pass one of these
# as the `algo` argument.
module Reduction
const thread = Val(:thread) # shared-memory tree reduction; available on all backends
const warp = Val(:warp)     # warp-shuffle reduction; query `supports_warp_reduction(backend)`
end

"""
@groupreduce op val neutral algo [groupsize]

Perform group reduction of `val` using `op`.

# Arguments

- `algo` specifies which reduction algorithm to use:
- `Reduction.thread`:
Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory).
Available across all backends.
- `Reduction.warp`:
Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory).
Potentially faster, since it requires fewer writes to shared memory.
To query if backend supports warp reduction, use `supports_warp_reduction(backend)`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that needed? Shouldn't the backend go and use warp reductions if it can?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm now doing an auto-selection of the algorithm based on device function __supports_warp_reduction().


- `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`.

- `groupsize` specifies size of the workgroup.
If a kernel does not specify `groupsize` statically, then it is required to
provide `groupsize`.
It can also be used to perform a reduction across only the first `groupsize` threads
(if `groupsize < @groupsize()`).

# Returns

Result of the reduction.
"""
# Four-argument form: the workgroup size is taken from the kernel context,
# so the kernel must have a statically known groupsize.
macro groupreduce(op, val, neutral, algo)
    quote
        __groupreduce(
            $(esc(:__ctx__)),  # implicit context object injected by `@kernel`
            $(esc(op)),
            $(esc(val)),
            $(esc(neutral)),
            # unescaped `$groupsize` resolves hygienically to this module's
            # `groupsize` function, applied to the caller's context
            Val(prod($groupsize($(esc(:__ctx__))))),
            $(esc(algo)),
        )
    end
end

# Five-argument form: caller supplies `groupsize` explicitly, either because
# the kernel has no static groupsize or to reduce over only the first
# `groupsize` threads of the group.
macro groupreduce(op, val, neutral, algo, groupsize)
    quote
        __groupreduce(
            $(esc(:__ctx__)),  # implicit context object injected by `@kernel`
            $(esc(op)),
            $(esc(val)),
            $(esc(neutral)),
            Val($(esc(groupsize))),  # lift the size into the type domain
            $(esc(algo)),
        )
    end
end

# Shared-memory tree reduction across (the first `groupsize` threads of) the
# workgroup. Returns the reduced value on thread 1; other threads return
# their original `val` unchanged.
# NOTE(review): the halving loop appears to assume `groupsize` is a power of
# two — for other sizes some slots are never folded in; TODO confirm callers
# only use power-of-two group sizes.
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize}
    # One shared-memory slot per participating thread.
    storage = @localmem T groupsize

    local_idx = @index(Local)
    # Threads beyond `groupsize` (when reducing only a prefix of the group)
    # do not contribute.
    @inbounds local_idx ≤ groupsize && (storage[local_idx] = val)
    @synchronize()

    # Tree reduction: at each step the first `s` threads fold in the element
    # `s` slots away, halving the active range each iteration.
    s::UInt64 = groupsize ÷ 0x2
    while s > 0x0
        if (local_idx - 0x1) < s
            other_idx = local_idx + s
            if other_idx ≤ groupsize
                @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx])
            end
        end
        # All threads must reach this barrier, hence the barrier sits outside
        # the `if` above.
        @synchronize()
        s >>= 0x1
    end

    # Slot 1 holds the final result; only thread 1 picks it up.
    if local_idx == 0x1
        @inbounds val = storage[local_idx]
    end
    return val
end

# Warp groupreduce.

# Thin wrapper over the backend's `__shfl_down` device function, so kernel
# code can write `@shfl_down(val, offset)`; the unescaped `$__shfl_down`
# resolves hygienically to this module's function.
macro shfl_down(val, offset)
    quote
        $__shfl_down($(esc(val)), $(esc(offset)))
    end
end

# Backends should implement these two.
function __shfl_down end              # device-level warp shuffle-down; intentionally no fallback
supports_warp_reduction(::CPU) = false # CPU backend has no warp primitives

# Reduce `val` across the lanes of a single warp using shuffle-down moves.
#
# Starting at half the warp width and halving each step, after
# log2(warpsize) iterations lane 1 holds `op` folded over all lanes.
# Uses the module-level `__warpsize` constant rather than a hard-coded 32,
# so the assumed warp width is defined in exactly one place.
@inline function __warp_reduce(val, op)
    offset::UInt32 = __warpsize ÷ 0x2
    while offset > 0x0
        val = op(val, @shfl_down(val, offset))
        offset >>= 0x1
    end
    return val
end

# Assume warp is 32 lanes.
# NOTE(review): some hardware uses other widths (e.g. wavefront 64) — TODO
# confirm backends with a different width override or avoid the warp algorithm.
const __warpsize::UInt32 = 32
# Maximum number of warps (for a groupsize = 1024), i.e. the number of
# shared-memory bins needed by the warp algorithm.
const __warp_bins::UInt32 = 32

# Two-stage warp reduction: each warp reduces via shuffles and writes its
# partial result to a shared-memory bin; the first warp then reduces the bins.
# The final result is returned by lane 1 of warp 1.
function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize}
    # One bin per warp (up to 32 warps, i.e. groupsize ≤ 1024).
    storage = @localmem T __warp_bins

    local_idx = @index(Local)
    lane = (local_idx - 0x1) % __warpsize + 0x1    # 1-based lane within the warp
    warp_id = (local_idx - 0x1) ÷ __warpsize + 0x1 # 1-based warp index

    # Each warp performs a reduction and writes results into its own bin in `storage`.
    val = __warp_reduce(val, op)
    @inbounds lane == 0x1 && (storage[warp_id] = val)
    @synchronize()

    # Final reduction of the `storage` on the first warp. Lanes beyond the
    # number of bins load `neutral` so they don't perturb the result.
    # NOTE(review): `groupsize ÷ __warpsize` truncates, so the bin of a trailing
    # partial warp would be ignored — assumes `groupsize` is a multiple of the
    # warp size; TODO confirm.
    within_storage = (local_idx - 0x1) < groupsize ÷ __warpsize
    @inbounds val = within_storage ? storage[lane] : neutral
    warp_id == 0x1 && (val = __warp_reduce(val, op))
    return val
end
48 changes: 48 additions & 0 deletions test/groupreduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Reduce `x` within one workgroup using the context-derived groupsize and
# write the group result to `y[1]`.
# `cpu=false`: the kernel reaches `@synchronize` inside `@groupreduce`
# (non-top-level), which the CPU backend cannot execute (per review).
@kernel cpu=false function groupreduce_1!(y, x, op, neutral, algo)
    i = @index(Global)
    # Out-of-range work-items contribute the neutral element so they don't
    # perturb the reduction.
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo)
    # Only the first work-item stores the result.
    i == 1 && (y[1] = res)
end

# Same as `groupreduce_1!` but with an explicit `groupsize`, exercising the
# 5-argument `@groupreduce` form.
# `cpu=false`: the kernel reaches `@synchronize` inside `@groupreduce`
# (non-top-level), which the CPU backend cannot execute (per review).
@kernel cpu=false function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize}
    i = @index(Global)
    # Out-of-range work-items contribute the neutral element.
    val = i > length(x) ? neutral : x[i]
    res = @groupreduce(op, val, neutral, algo, groupsize)
    # Only the first work-item stores the result.
    i == 1 && (y[1] = res)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These need to be cpu=false since you are using non-top-level @synchronize

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


# Run the `@groupreduce` test suite against `backend` with array type `AT`.
# Checks both reduction algorithms over several eltypes and group sizes;
# warp tests run only when the backend reports warp-reduction support.
function groupreduce_testsuite(backend, AT)
    # Explicit `return` of the testset result, as suggested in review.
    return @testset "@groupreduce" begin
        @testset "thread reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
            x = AT(ones(T, n))
            y = AT(zeros(T, 1))

            # Whole-group reduction of `n` ones must equal `n`.
            groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.thread; ndrange = n)
            @test Array(y)[1] == n

            # Explicit groupsize: only the first 128 (resp. 64) values reduce.
            groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(128); ndrange = n)
            @test Array(y)[1] == 128

            groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(64); ndrange = n)
            @test Array(y)[1] == 64
        end

        if KernelAbstractions.supports_warp_reduction(backend())
            @testset "warp reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024)
                x = AT(ones(T, n))
                y = AT(zeros(T, 1))

                groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.warp; ndrange = n)
                @test Array(y)[1] == n

                groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(128); ndrange = n)
                @test Array(y)[1] == 128

                groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(64); ndrange = n)
                @test Array(y)[1] == 64
            end
        end
    end
end
5 changes: 5 additions & 0 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ include("reflection.jl")
include("examples.jl")
include("convert.jl")
include("specialfunctions.jl")
include("groupreduce.jl")

function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
@conditional_testset "Unittests" skip_tests begin
Expand Down Expand Up @@ -92,6 +93,10 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
examples_testsuite(backend_str)
end

@conditional_testset "@groupreduce" skip_tests begin
groupreduce_testsuite(backend, AT)
end

return
end

Expand Down
Loading