- 
                Notifications
    
You must be signed in to change notification settings  - Fork 80
 
groupreduction and subgroupreduction #421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
cc510ac
              1c1e459
              dd3a0ca
              546e8c9
              3602808
              42a7960
              c96a24a
              d2d65be
              128a5f0
              b899685
              1cdb6d6
              1fea4cc
              41356d3
              88662f8
              45844ce
              c5dc356
              e2c8f84
              700d5f2
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| export @groupreduce, @subgroupreduce | ||
| 
     | 
||
| """ | ||
| @subgroupreduce(op, val) | ||
| reduce values across a subgroup. This operation is only supported if subgroups are supported by the backend. | ||
| """ | ||
| macro subgroupreduce(op, val) | ||
| quote | ||
| $__subgroupreduce($(esc(op)),$(esc(val))) | ||
| end | ||
| end | ||
| 
     | 
||
| function __subgroupreduce(op, val) | ||
| error("@subgroupreduce used outside kernel, not captured, or not supported") | ||
| end | ||
| 
     | 
||
| """ | ||
| @groupreduce(op, val, neutral, use_subgroups) | ||
| Reduce values across a block | ||
| - `op`: the operator of the reduction | ||
| - `val`: value that each thread contibutes to the values that need to be reduced | ||
| - `netral`: value of the operator, so that `op(netural, neutral) = neutral`` | ||
| - `use_subgroups`: make use of the subgroupreduction of the groupreduction | ||
                
       | 
||
| """ | ||
| macro groupreduce(op, val, neutral, use_subgroups) | ||
| quote | ||
| $__groupreduce($(esc(:__ctx__)),$(esc(op)), $(esc(val)), $(esc(neutral)), $(esc(typeof(val))), Val(use_subgroups)) | ||
| end | ||
| end | ||
| 
     | 
||
| @inline function __groupreduce(__ctx__, op, val, neutral, ::Type{T}, ::Val{true}) where {T} | ||
| idx_in_group = @index(Local) | ||
| groupsize = @groupsize()[1] | ||
| subgroupsize = @subgroupsize() | ||
| 
     | 
||
| localmem = @localmem(T, subgroupsize) | ||
| 
     | 
||
| idx_subgroup, idx_in_subgroup = fldmod1(idx_in_group, subgroupsize) | ||
| 
     | 
||
| # first subgroup reduction | ||
| val = @subgroupreduce(op, val) | ||
| 
     | 
||
| # store partial results in local memory | ||
| if idx_in_subgroup == 1 | ||
| @inbounds localmem[idx_in_subgroup] = val | ||
| end | ||
| 
     | 
||
| @synchronize() | ||
| 
     | 
||
| val = if idx_in_subgroup <= fld1(groupsize, subgroupsize) | ||
| @inbounds localmem[idx_in_subgroup] | ||
| else | ||
| neutral | ||
| end | ||
| 
     | 
||
| # second subgroup reduction to reduce partial results | ||
| if idx_in_subgroup == 1 | ||
| val = @subgroupreduce(op, val) | ||
| end | ||
| 
     | 
||
| return val | ||
| end | ||
| 
     | 
||
| @inline function __groupreduce(__ctx__, op, val, neutral, ::Type{T}, ::Val{false}) where {T} | ||
| idx_in_group = @index(Local) | ||
| groupsize = @groupsize()[1] | ||
| 
     | 
||
| localmem = @localmem(T, groupsize) | ||
                
       | 
||
| 
     | 
||
| @inbounds localmem[idx_in_group] = val | ||
| 
     | 
||
| # perform the reduction | ||
| d = 1 | ||
| while d < groupsize | ||
| @synchronize() | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Workaround?  | 
||
| index = 2 * d * (idx_in_group-1) + 1 | ||
| @inbounds if index <= groupsize | ||
| other_val = if index + d <= groupsize | ||
| localmem[index+d] | ||
| else | ||
| neutral | ||
| end | ||
| localmem[index] = op(localmem[index], other_val) | ||
| end | ||
| d *= 2 | ||
| end | ||
| 
     | 
||
| # load the final value on the first thread | ||
| if idx_in_group == 1 | ||
| val = @inbounds localmem[idx_in_group] | ||
| end | ||
| 
     | 
||
| return val | ||
| end | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| using KernelAbstractions, Test | ||
| 
     | 
||
| 
     | 
||
| 
     | 
||
| 
     | 
||
| @kernel function reduce(a, b, op, neutral) | ||
| idx_in_group = @index(Local) | ||
| 
     | 
||
| val = a[idx_in_group] | ||
| 
     | 
||
| val = @groupreduce(op, val, netral, false) | ||
| 
     | 
||
| b[1] = val | ||
| end | ||
| 
     | 
||
| function(backend, ArrayT) | ||
| @testset "groupreduce one group" begin | ||
| @testset for op in (+,*,max,min) | ||
| @testset for type in (Int32, Float32, Float64) | ||
| @test test_1group_groupreduce(backend, ArrayT ,op, type, op(neutral)) | ||
| end | ||
| end | ||
| end | ||
| end | ||
| 
     | 
||
| function test_1group_groupreduce(backend,ArrayT, op, type, neutral) | ||
| a = rand(type, 32) | ||
| b = ArrayT(a) | ||
| 
     | 
||
| c = similar(b,1) | ||
| reduce(a, c, op, neutral) | ||
| 
     | 
||
| expected = mapreduce(x->x^2, +, a) | ||
| actual = c[1] | ||
| return expected = actual | ||
| end | ||
| 
     | 
||
| 
     | 
||
| 
     | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.