- 
                Notifications
    You must be signed in to change notification settings 
- Fork 79
Implement groupreduce API #559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
e1a110f
              ff4097f
              6a35eb8
              224e8c8
              4a8e707
              7c923fb
              a647992
              cbc8bd5
              bb77270
              db5abc5
              618c840
              344d484
              1cd2d2f
              7c96e5a
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| export @groupreduce | ||
|  | ||
| module Reduction | ||
| const thread = Val(:thread) | ||
| const warp = Val(:warp) | ||
| end | ||
|  | ||
| """ | ||
| @groupreduce op val neutral [groupsize] | ||
|  | ||
| Perform group reduction of `val` using `op`. | ||
| If backend supports warp reduction, it will use it instead of thread reduction. | ||
|  | ||
| # Arguments | ||
|  | ||
| - `neutral` should be a neutral w.r.t. `op`, such that `op(neutral, x) == x`. | ||
|  | ||
| - `groupsize` specifies size of the workgroup. | ||
| If a kernel does not specifies `groupsize` statically, then it is required to | ||
| provide `groupsize`. | ||
| Also can be used to perform reduction accross first `groupsize` threads | ||
| (if `groupsize < @groupsize()`). | ||
|  | ||
| # Returns | ||
|  | ||
| Result of the reduction. | ||
| """ | ||
| macro groupreduce(op, val, neutral) | ||
| return quote | ||
| if __supports_warp_reduction() | ||
| __groupreduce( | ||
| $(esc(:__ctx__)), | ||
| $(esc(op)), | ||
| $(esc(val)), | ||
| $(esc(neutral)), | ||
| Val(prod($groupsize($(esc(:__ctx__))))), | ||
| $(esc(Reduction.warp)), | ||
| ) | ||
| else | ||
| __groupreduce( | ||
| $(esc(:__ctx__)), | ||
| $(esc(op)), | ||
| $(esc(val)), | ||
| $(esc(neutral)), | ||
| Val(prod($groupsize($(esc(:__ctx__))))), | ||
| $(esc(Reduction.thread)), | ||
| ) | ||
| end | ||
| end | ||
| end | ||
|  | ||
| macro groupreduce(op, val, neutral, groupsize) | ||
| return quote | ||
| if __supports_warp_reduction() | ||
| __groupreduce( | ||
| $(esc(:__ctx__)), | ||
| $(esc(op)), | ||
| $(esc(val)), | ||
| $(esc(neutral)), | ||
| Val($(esc(groupsize))), | ||
| $(esc(Reduction.warp)), | ||
| ) | ||
| else | ||
| __groupreduce( | ||
| $(esc(:__ctx__)), | ||
| $(esc(op)), | ||
| $(esc(val)), | ||
| $(esc(neutral)), | ||
| Val($(esc(groupsize))), | ||
| $(esc(Reduction.thread)), | ||
| ) | ||
| end | ||
| end | ||
| end | ||
|  | ||
| function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:thread}) where {T, groupsize} | ||
| storage = @localmem T groupsize | ||
|  | ||
| local_idx = @index(Local) | ||
| @inbounds local_idx ≤ groupsize && (storage[local_idx] = val) | ||
| @synchronize() | ||
|  | ||
| s::UInt64 = groupsize ÷ 0x02 | ||
| while s > 0x00 | ||
| if (local_idx - 0x01) < s | ||
| other_idx = local_idx + s | ||
| if other_idx ≤ groupsize | ||
| @inbounds storage[local_idx] = op(storage[local_idx], storage[other_idx]) | ||
| end | ||
| end | ||
| @synchronize() | ||
| s >>= 0x01 | ||
| end | ||
|  | ||
| if local_idx == 0x01 | ||
| @inbounds val = storage[local_idx] | ||
| end | ||
| return val | ||
| end | ||
|  | ||
| # Warp groupreduce. | ||
|  | ||
| # NOTE: Backends should implement these two device functions (with `@device_override`). | ||
| function __shfl_down end | ||
| function __supports_warp_reduction() end | ||
|  | ||
| # Assume warp is 32 lanes. | ||
| const __warpsize = UInt32(32) | ||
| # Maximum number of warps (for a groupsize = 1024). | ||
| const __warp_bins = UInt32(32) | ||
|  | ||
| @inline function __warp_reduce(val, op) | ||
| offset::UInt32 = __warpsize ÷ 0x02 | ||
| while offset > 0x00 | ||
| val = op(val, __shfl_down(val, offset)) | ||
| offset >>= 0x01 | ||
| end | ||
| return val | ||
| end | ||
|  | ||
| function __groupreduce(__ctx__, op, val::T, neutral::T, ::Val{groupsize}, ::Val{:warp}) where {T, groupsize} | ||
| storage = @localmem T __warp_bins | ||
|  | ||
| local_idx = @index(Local) | ||
| lane = (local_idx - 0x01) % __warpsize + 0x01 | ||
| warp_id = (local_idx - 0x01) ÷ __warpsize + 0x01 | ||
|  | ||
| # Each warp performs a reduction and writes results into its own bin in `storage`. | ||
| val = __warp_reduce(val, op) | ||
| @inbounds lane == 0x01 && (storage[warp_id] = val) | ||
| @synchronize() | ||
|  | ||
| # Final reduction of the `storage` on the first warp. | ||
| within_storage = (local_idx - 0x01) < groupsize ÷ __warpsize | ||
| @inbounds val = within_storage ? storage[lane] : neutral | ||
| warp_id == 0x01 && (val = __warp_reduce(val, op)) | ||
| return val | ||
| end | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,35 @@ | ||||||||||||||||||||||||||
| @kernel cpu=false function groupreduce_1!(y, x, op, neutral) | ||||||||||||||||||||||||||
| i = @index(Global) | ||||||||||||||||||||||||||
| 
      Comment on lines
    
      +1
     to 
      +2
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 
      Comment on lines
    
      +1
     to 
      +2
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 
      Comment on lines
    
      +1
     to 
      +2
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 | ||||||||||||||||||||||||||
| val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||||||
| res = @groupreduce(op, val, neutral) | ||||||||||||||||||||||||||
| i == 1 && (y[1] = res) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| @kernel cpu=false function groupreduce_2!(y, x, op, neutral, ::Val{groupsize}) where {groupsize} | ||||||||||||||||||||||||||
| i = @index(Global) | ||||||||||||||||||||||||||
| val = i > length(x) ? neutral : x[i] | ||||||||||||||||||||||||||
| res = @groupreduce(op, val, neutral, groupsize) | ||||||||||||||||||||||||||
| i == 1 && (y[1] = res) | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| function groupreduce_testsuite(backend, AT) | ||||||||||||||||||||||||||
| # TODO should be a better way of querying max groupsize | ||||||||||||||||||||||||||
| groupsizes = "$backend" == "oneAPIBackend" ? | ||||||||||||||||||||||||||
| (256,) : | ||||||||||||||||||||||||||
| (256, 512, 1024) | ||||||||||||||||||||||||||
| @testset "@groupreduce" begin | ||||||||||||||||||||||||||
|         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
         
                  pxl-th marked this conversation as resolved.
              Show resolved
            Hide resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 
        Suggested change
       
 | ||||||||||||||||||||||||||
| @testset "T=$T, n=$n" for T in (Float16, Float32, Float64, Int16, Int32, Int64), n in groupsizes | ||||||||||||||||||||||||||
| x = AT(ones(T, n)) | ||||||||||||||||||||||||||
| y = AT(zeros(T, 1)) | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| groupreduce_1!(backend(), n)(y, x, +, zero(T); ndrange = n) | ||||||||||||||||||||||||||
| @test Array(y)[1] == n | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| groupreduce_2!(backend())(y, x, +, zero(T), Val(128); ndrange = n) | ||||||||||||||||||||||||||
| @test Array(y)[1] == 128 | ||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||
| groupreduce_2!(backend())(y, x, +, zero(T), Val(64); ndrange = n) | ||||||||||||||||||||||||||
| @test Array(y)[1] == 64 | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently this is not legal.
#262 might need to wait until #556
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I assume this code is GPU only anyways)