@@ -23,6 +23,17 @@ namespace subgroup2
23
23
namespace impl
24
24
{
25
25
26
+ // forward declarations
27
+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
28
+ struct inclusive_scan;
29
+
30
+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
31
+ struct exclusive_scan;
32
+
33
+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
34
+ struct reduction;
35
+
36
+
26
37
// BinOp needed to specialize native
27
38
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
28
39
struct inclusive_scan
@@ -31,7 +42,7 @@ struct inclusive_scan
31
42
using scalar_t = typename Params::scalar_t;
32
43
using binop_t = typename Params::binop_t;
33
44
// assert binop_t == BinOp
34
- using exclusive_scan_op_t = subgroup::impl:: exclusive_scan<binop_t, native>;
45
+ using exclusive_scan_op_t = exclusive_scan<Params, binop_t, 1 , native>;
35
46
36
47
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
37
48
@@ -43,7 +54,7 @@ struct inclusive_scan
43
54
[unroll]
44
55
for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
45
56
retval[i] = binop (retval[i-1 ], value[i]);
46
-
57
+
47
58
exclusive_scan_op_t op;
48
59
scalar_t exclusive = op (retval[ItemsPerInvocation-1 ]);
49
60
@@ -60,7 +71,7 @@ struct exclusive_scan
60
71
using type_t = typename Params::type_t;
61
72
using scalar_t = typename Params::scalar_t;
62
73
using binop_t = typename Params::binop_t;
63
- using inclusive_scan_op_t = subgroup2::impl:: inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
74
+ using inclusive_scan_op_t = inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
64
75
65
76
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
66
77
@@ -86,7 +97,7 @@ struct reduction
86
97
using type_t = typename Params::type_t;
87
98
using scalar_t = typename Params::scalar_t;
88
99
using binop_t = typename Params::binop_t;
89
- using op_t = subgroup::impl:: reduction<binop_t, native>;
100
+ using op_t = reduction<Params, binop_t, 1 , native>;
90
101
91
102
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
92
103
@@ -142,25 +153,25 @@ struct inclusive_scan<Params, BinOp, 1, false>
142
153
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
143
154
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
144
155
145
- type_t operator ()(type_t value)
156
+ scalar_t operator ()(scalar_t value)
146
157
{
147
158
return __call (value);
148
159
}
149
160
150
- static type_t __call (type_t value)
161
+ static scalar_t __call (scalar_t value)
151
162
{
152
163
binop_t op;
153
164
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID ();
154
-
155
- type_t rhs = glsl::subgroupShuffleUp<type_t >(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
165
+
166
+ scalar_t rhs = glsl::subgroupShuffleUp<scalar_t >(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
156
167
value = op (value, hlsl::mix (rhs, binop_t::identity, subgroupInvocation < 1u));
157
-
168
+
158
169
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
159
170
[unroll]
160
171
for (uint32_t i = 1 ; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
161
172
{
162
- const uint32_t step = i * 2 ;
163
- rhs = glsl::subgroupShuffleUp<type_t >(value, step);
173
+ const uint32_t step = 1u << i ;
174
+ rhs = glsl::subgroupShuffleUp<scalar_t >(value, step);
164
175
value = op (value, hlsl::mix (rhs, binop_t::identity, subgroupInvocation < step));
165
176
}
166
177
return value;
@@ -174,13 +185,13 @@ struct exclusive_scan<Params, BinOp, 1, false>
174
185
using scalar_t = typename Params::scalar_t;
175
186
using binop_t = typename Params::binop_t;
176
187
177
- type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
188
+ scalar_t operator ()(scalar_t value)
178
189
{
179
190
value = inclusive_scan<Params, BinOp, 1 , false >::__call (value);
180
191
// can't risk getting short-circuited, need to store to a var
181
- type_t left = glsl::subgroupShuffleUp<type_t >(value,1 );
192
+ scalar_t left = glsl::subgroupShuffleUp<scalar_t >(value,1 );
182
193
// the first invocation doesn't have anything in its left so we set to the binop's identity value for exlusive scan
183
- return hlsl:: mix (binop_t::identity, left, bool (glsl::gl_SubgroupInvocationID ())) ;
194
+ return bool (glsl::gl_SubgroupInvocationID ()) ? left:binop_t::identity ;
184
195
}
185
196
};
186
197
@@ -190,11 +201,21 @@ struct reduction<Params, BinOp, 1, false>
190
201
using type_t = typename Params::type_t;
191
202
using scalar_t = typename Params::scalar_t;
192
203
using binop_t = typename Params::binop_t;
204
+ using config_t = typename Params::config_t;
193
205
194
- scalar_t operator ()(NBL_CONST_REF_ARG (type_t) value)
206
+ // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
207
+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
208
+
209
+ scalar_t operator ()(scalar_t value)
195
210
{
196
- // take the last subgroup invocation's value for the reduction
197
- return subgroup::BroadcastLast<type_t>(inclusive_scan<Params, BinOp, 1 , false >::__call (value));
211
+ binop_t op;
212
+
213
+ const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
214
+ [unroll]
215
+ for (uint32_t i = 0 ; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
216
+ value = op (glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);
217
+
218
+ return value;
198
219
}
199
220
};
200
221
0 commit comments