- 
                Notifications
    You must be signed in to change notification settings 
- Fork 79
Implement groupreduce API #559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
e1a110f
              ff4097f
              6a35eb8
              224e8c8
              4a8e707
              7c923fb
              a647992
              cbc8bd5
              bb77270
              db5abc5
              618c840
              344d484
              1cd2d2f
              7c96e5a
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| export @groupreduce, Reduction | ||
|  | ||
| module Reduction | ||
| const thread = Val(:thread) | ||
| const warp = Val(:warp) | ||
| end | ||
|  | ||
| """ | ||
| @groupreduce op val neutral algo [groupsize] | ||
|  | ||
| Perform group reduction of `val` using `op`. | ||
|  | ||
| # Arguments | ||
|  | ||
| - `algo` specifies which reduction algorithm to use: | ||
| - `Reduction.thread`: | ||
| Perform thread group reduction (requires `groupsize * sizeof(T)` bytes of shared memory). | ||
| Available accross all backends. | ||
| - `Reduction.warp`: | ||
| Perform warp group reduction (requires `32 * sizeof(T)` bytes of shared memory). | ||
| Potentially faster, since requires fewer writes to shared memory. | ||
| To query if backend supports warp reduction, use `supports_warp_reduction(backend)`. | ||
|  | ||
| - `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`. | ||
|  | ||
| - `groupsize` specifies size of the workgroup. | ||
| If a kernel does not specifies `groupsize` statically, then it is required to | ||
| provide `groupsize`. | ||
| Also can be used to perform reduction accross first `groupsize` threads | ||
| (if `groupsize < @groupsize()`). | ||
|  | ||
| # Returns | ||
|  | ||
| Result of the reduction. | ||
| """ | ||
| macro groupreduce(op, val, neutral, algo) | ||
| quote | ||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| __groupreduce( | ||
| $(esc(:__ctx__)), | ||
| $(esc(op)), | ||
| $(esc(val)), | ||
| $(esc(neutral)), | ||
| Val(prod($groupsize($(esc(:__ctx__))))), | ||
| $(esc(algo)), | ||
| ) | ||
| end | ||
| end | ||
|  | ||
| macro groupreduce(op, val, neutral, algo, groupsize) | ||
| quote | ||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| __groupreduce( | ||
| $(esc(:__ctx__)), | ||
| $(esc(op)), | ||
| $(esc(val)), | ||
| $(esc(neutral)), | ||
| Val($(esc(groupsize))), | ||
| $(esc(algo)), | ||
| ) | ||
| end | ||
| end | ||
|  | ||
| function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize} | ||
| storage = @localmem T groupsize | ||
|  | ||
| local_idx = @index(Local) | ||
| @inbounds local_idx ≤ groupsize && (storage[local_idx] = val) | ||
| @synchronize() | ||
|  | ||
| s::UInt64 = groupsize ÷ 0x2 | ||
| while s > 0x0 | ||
| if (local_idx - 0x1) < s | ||
| other_idx = local_idx + s | ||
| if other_idx ≤ groupsize | ||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx]) | ||
| end | ||
| end | ||
| @synchronize() | ||
| s >>= 0x1 | ||
| end | ||
|  | ||
| if local_idx == 0x1 | ||
| @inbounds val = storage[local_idx] | ||
| end | ||
| return val | ||
| end | ||
|  | ||
| # Warp groupreduce. | ||
|  | ||
| macro shfl_down(val, offset) | ||
| quote | ||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| $__shfl_down($(esc(val)), $(esc(offset))) | ||
| end | ||
| end | ||
|  | ||
| # Backends should implement these two. | ||
| function __shfl_down end | ||
| supports_warp_reduction(::CPU) = false | ||
|  | ||
| @inline function __warp_reduce(val, op) | ||
| offset::UInt32 = UInt32(32) ÷ 0x2 | ||
| while offset > 0x0 | ||
| val = op(val, @shfl_down(val, offset)) | ||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| offset >>= 0x1 | ||
| end | ||
| return val | ||
| end | ||
|  | ||
| # Assume warp is 32 lanes. | ||
| const __warpsize::UInt32 = 32 | ||
| # Maximum number of warps (for a groupsize = 1024). | ||
| const __warp_bins::UInt32 = 32 | ||
|  | ||
| function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize} | ||
| storage = @localmem T __warp_bins | ||
|  | ||
| local_idx = @index(Local) | ||
| lane = (local_idx - 0x1) % __warpsize + 0x1 | ||
| warp_id = (local_idx - 0x1) ÷ __warpsize + 0x1 | ||
|  | ||
| # Each warp performs a reduction and writes results into its own bin in `storage`. | ||
| val = __warp_reduce(val, op) | ||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| @inbounds lane == 0x1 && (storage[warp_id] = val) | ||
| @synchronize() | ||
|  | ||
| # Final reduction of the `storage` on the first warp. | ||
| within_storage = (local_idx - 0x1) < groupsize ÷ __warpsize | ||
| @inbounds val = within_storage ? storage[lane] : neutral | ||
| warp_id == 0x1 && (val = __warp_reduce(val, op)) | ||
| return val | ||
| end | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,48 @@ | ||||||||||||||||||||||
| @kernel function groupreduce_1!(y, x, op, neutral, algo) | ||||||||||||||||||||||
| i = @index(Global) | ||||||||||||||||||||||
| val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||
| res = @groupreduce(op, val, neutral, algo) | ||||||||||||||||||||||
| i == 1 && (y[1] = res) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| @kernel function groupreduce_2!(y, x, op, neutral, algo, ::Val{groupsize}) where {groupsize} | ||||||||||||||||||||||
| i = @index(Global) | ||||||||||||||||||||||
| val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||
| res = @groupreduce(op, val, neutral, algo, groupsize) | ||||||||||||||||||||||
| i == 1 && (y[1] = res) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|          | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| function groupreduce_testsuite(backend, AT) | ||||||||||||||||||||||
| @testset "@groupreduce" begin | ||||||||||||||||||||||
|         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 | ||||||||||||||||||||||
| @testset "thread reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024) | ||||||||||||||||||||||
| x = AT(ones(T, n)) | ||||||||||||||||||||||
| y = AT(zeros(T, 1)) | ||||||||||||||||||||||
|         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.thread; ndrange=n) | ||||||||||||||||||||||
| @test Array(y)[1] == n | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(128); ndrange=n) | ||||||||||||||||||||||
| @test Array(y)[1] == 128 | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| groupreduce_2!(backend())(y, x, +, zero(T), Reduction.thread, Val(64); ndrange=n) | ||||||||||||||||||||||
| @test Array(y)[1] == 64 | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| warp_reduction = KernelAbstractions.supports_warp_reduction(backend()) | ||||||||||||||||||||||
| if warp_reduction | ||||||||||||||||||||||
| @testset "warp reduction T=$T, n=$n" for T in (Float16, Float32, Int32, Int64), n in (256, 512, 1024) | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| x = AT(ones(T, n)) | ||||||||||||||||||||||
| y = AT(zeros(T, 1)) | ||||||||||||||||||||||
| groupreduce_1!(backend(), n)(y, x, +, zero(T), Reduction.warp; ndrange=n) | ||||||||||||||||||||||
| @test Array(y)[1] == n | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
|         
                  pxl-th marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||||||||||||||||||||
| groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(128); ndrange=n) | ||||||||||||||||||||||
| @test Array(y)[1] == 128 | ||||||||||||||||||||||
|  | ||||||||||||||||||||||
| groupreduce_2!(backend())(y, x, +, zero(T), Reduction.warp, Val(64); ndrange=n) | ||||||||||||||||||||||
| @test Array(y)[1] == 64 | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is that needed? Shouldn't the backend go and use warp reductions if it can?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm now doing an auto-selection of the algorithm based on device function
__supports_warp_reduction().