1- export @groupreduce
2-
3- module Reduction
4- const thread = Val (:thread )
5- const warp = Val (:warp )
6- end
1+ export @groupreduce , @warp_groupreduce
72
83"""
94 @groupreduce op val neutral [groupsize]
@@ -25,55 +20,21 @@ If backend supports warp reduction, it will use it instead of thread reduction.
2520
2621Result of the reduction.
2722"""
28- macro groupreduce (op, val, neutral)
29- return quote
30- if __supports_warp_reduction ()
31- __groupreduce (
32- $ (esc (:__ctx__ )),
33- $ (esc (op)),
34- $ (esc (val)),
35- $ (esc (neutral)),
36- Val (prod ($ groupsize ($ (esc (:__ctx__ ))))),
37- $ (esc (Reduction. warp)),
38- )
39- else
40- __groupreduce (
41- $ (esc (:__ctx__ )),
42- $ (esc (op)),
43- $ (esc (val)),
44- $ (esc (neutral)),
45- Val (prod ($ groupsize ($ (esc (:__ctx__ ))))),
46- $ (esc (Reduction. thread)),
47- )
48- end
49- end
23+ macro groupreduce (op, val)
24+ :(__thread_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
25+ end
26+ macro groupreduce (op, val, groupsize)
27+ :(__thread_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), Val ($ (esc (groupsize)))))
5028end
5129
52- macro groupreduce (op, val, neutral, groupsize)
53- return quote
54- if __supports_warp_reduction ()
55- __groupreduce (
56- $ (esc (:__ctx__ )),
57- $ (esc (op)),
58- $ (esc (val)),
59- $ (esc (neutral)),
60- Val ($ (esc (groupsize))),
61- $ (esc (Reduction. warp)),
62- )
63- else
64- __groupreduce (
65- $ (esc (:__ctx__ )),
66- $ (esc (op)),
67- $ (esc (val)),
68- $ (esc (neutral)),
69- Val ($ (esc (groupsize))),
70- $ (esc (Reduction. thread)),
71- )
72- end
73- end
30+ macro warp_groupreduce (op, val, neutral)
31+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
32+ end
33+ macro warp_groupreduce (op, val, neutral, groupsize)
34+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val ($ (esc (groupsize)))))
7435end
7536
76- function __groupreduce (__ctx__, op, val:: T , neutral :: T , :: Val{groupsize} , :: Val{:thread } ) where {T, groupsize}
37+ function __thread_groupreduce (__ctx__, op, val:: T , :: Val{groupsize} ) where {T, groupsize}
7738 storage = @localmem T groupsize
7839
7940 local_idx = @index (Local)
@@ -120,7 +81,7 @@ const __warp_bins = UInt32(32)
12081 return val
12182end
12283
123- function __groupreduce (__ctx__, op, val:: T , neutral:: T , :: Val{groupsize} , :: Val{:warp } ) where {T, groupsize}
84+ function __warp_groupreduce (__ctx__, op, val:: T , neutral:: T , :: Val{groupsize} ) where {T, groupsize}
12485 storage = @localmem T __warp_bins
12586
12687 local_idx = @index (Local)
0 commit comments