11@kernel  cpu= false  function  groupreduce_1! (y, x, op, neutral)
22    i =  @index (Global)
33    val =  i >  length (x) ?  neutral :  x[i]
4-     res =  @groupreduce (op, val, neutral )
4+     res =  @groupreduce (op, val)
55    i ==  1  &&  (y[1 ] =  res)
66end 
77
88@kernel  cpu= false  function  groupreduce_2! (y, x, op, neutral, :: Val{groupsize} ) where  {groupsize}
99    i =  @index (Global)
1010    val =  i >  length (x) ?  neutral :  x[i]
11-     res =  @groupreduce (op, val, neutral, groupsize)
11+     res =  @groupreduce (op, val, groupsize)
12+     i ==  1  &&  (y[1 ] =  res)
13+ end 
14+ 
15+ @kernel  cpu= false  function  warp_groupreduce_1! (y, x, op, neutral)
16+     i =  @index (Global)
17+     val =  i >  length (x) ?  neutral :  x[i]
18+     res =  @warp_groupreduce (op, val, neutral)
19+     i ==  1  &&  (y[1 ] =  res)
20+ end 
21+ 
22+ @kernel  cpu= false  function  warp_groupreduce_2! (y, x, op, neutral, :: Val{groupsize} ) where  {groupsize}
23+     i =  @index (Global)
24+     val =  i >  length (x) ?  neutral :  x[i]
25+     res =  @warp_groupreduce (op, val, neutral, groupsize)
1226    i ==  1  &&  (y[1 ] =  res)
1327end 
1428
@@ -17,19 +31,40 @@ function groupreduce_testsuite(backend, AT)
1731    groupsizes =  " $backend "   ==  " oneAPIBackend"   ? 
1832        (256 ,) : 
1933        (256 , 512 , 1024 )
34+ 
2035    @testset  " @groupreduce"   begin 
2136        @testset  " T=$T , n=$n "   for  T in  (Float16, Float32, Int16, Int32, Int64), n in  groupsizes
2237            x =  AT (ones (T, n))
2338            y =  AT (zeros (T, 1 ))
39+             neutral =  zero (T)
40+             op =  + 
2441
25-             groupreduce_1! (backend (), n)(y, x, + ,  zero (T) ; ndrange =  n)
42+             groupreduce_1! (backend (), n)(y, x, op, neutral ; ndrange =  n)
2643            @test  Array (y)[1 ] ==  n
2744
28-             groupreduce_2! (backend ())(y, x, + , zero (T), Val (128 ); ndrange =  n)
29-             @test  Array (y)[1 ] ==  128 
45+             for  groupsize in  (64 , 128 )
46+                 groupreduce_2! (backend ())(y, x, op, neutral, Val (groupsize); ndrange =  n)
47+                 @test  Array (y)[1 ] ==  groupsize
48+             end 
49+         end 
50+     end 
51+ 
52+     if  KernelAbstractions. supports_warp_reduction (backend ())
53+         @testset  " @warp_groupreduce"   begin 
54+             @testset  " T=$T , n=$n "   for  T in  (Float16, Float32, Int16, Int32, Int64), n in  groupsizes
55+                 x =  AT (ones (T, n))
56+                 y =  AT (zeros (T, 1 ))
57+                 neutral =  zero (T)
58+                 op =  + 
59+ 
60+                 warp_groupreduce_1! (backend (), n)(y, x, op, neutral; ndrange =  n)
61+                 @test  Array (y)[1 ] ==  n
3062
31-             groupreduce_2! (backend ())(y, x, + , zero (T), Val (64 ); ndrange =  n)
32-             @test  Array (y)[1 ] ==  64 
63+                 for  groupsize in  (64 , 128 )
64+                     warp_groupreduce_2! (backend ())(y, x, op, neutral, Val (groupsize); ndrange =  n)
65+                     @test  Array (y)[1 ] ==  groupsize
66+                 end 
67+             end 
3368        end 
3469    end 
3570end 
0 commit comments