@@ -4,7 +4,6 @@ export @groupreduce, @warp_groupreduce
44 @groupreduce op val neutral [groupsize]
55
66Perform group reduction of `val` using `op`.
7- If backend supports warp reduction, it will use it instead of thread reduction.
87
98# Arguments
109
@@ -27,13 +26,6 @@ macro groupreduce(op, val, groupsize)
2726 :(__thread_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), Val ($ (esc (groupsize)))))
2827end
2928
30- macro warp_groupreduce (op, val, neutral)
31- :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
32- end
33- macro warp_groupreduce (op, val, neutral, groupsize)
34- :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val ($ (esc (groupsize)))))
35- end
36-
3729function __thread_groupreduce (__ctx__, op, val:: T , :: Val{groupsize} ) where {T, groupsize}
3830 storage = @localmem T groupsize
3931
6153
6254# Warp groupreduce.
6355
64- # NOTE: Backends should implement these two device functions (with `@device_override`).
56+ """
57+ @warp_groupreduce op val neutral [groupsize]
58+
59+ Perform group reduction of `val` using `op`.
60+ Each warp within a workgroup performs its own reduction using [`shfl_down`](@ref) intrinsic,
61+ followed by final reduction over results of individual warp reductions.
62+
63+ !!! note
64+
65+ Use [`supports_warp_reduction`](@ref) to query if given backend supports warp reduction.
66+ """
67+ macro warp_groupreduce (op, val, neutral)
68+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val (prod ($ groupsize ($ (esc (:__ctx__ )))))))
69+ end
70+ macro warp_groupreduce (op, val, neutral, groupsize)
71+ :(__warp_groupreduce ($ (esc (:__ctx__ )), $ (esc (op)), $ (esc (val)), $ (esc (neutral)), Val ($ (esc (groupsize)))))
72+ end
73+
74+ """
75+ shfl_down(val::T, offset::Integer)::T where T
76+
77+ Read `val` from a lane with higher id given by `offset`.
78+ """
6579function shfl_down end
6680supports_warp_reduction () = false
67- # Host-variant.
81+
82+ """
83+ supports_warp_reduction(::Backend)
84+
85+ Query if given backend supports [`shfl_down`](@ref) intrinsic and thus warp reduction.
86+ """
6887supports_warp_reduction (:: Backend ) = false
6988
7089# Assume warp is 32 lanes.
0 commit comments