4
4
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
5
5
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
6
6
7
- // #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8
- // #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
7
+ #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8
+ #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
9
9
10
- // #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
10
+ #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11
+ #include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
11
12
12
- // #include "nbl/builtin/hlsl/functional.hlsl"
13
+ #include "nbl/builtin/hlsl/functional.hlsl"
13
14
14
- #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
15
+ // #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
15
16
16
17
namespace nbl
17
18
{
@@ -23,12 +24,14 @@ namespace subgroup2
23
24
namespace impl
24
25
{
25
26
26
- template<class Params, uint32_t ItemsPerInvocation, bool native>
27
+ // BinOp needed to specialize native
28
+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
27
29
struct inclusive_scan
28
30
{
29
31
using type_t = typename Params::type_t;
30
32
using scalar_t = typename Params::scalar_t;
31
33
using binop_t = typename Params::binop_t;
34
+ // assert binop_t == BinOp
32
35
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
33
36
34
37
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -52,13 +55,13 @@ struct inclusive_scan
52
55
}
53
56
};
54
57
55
- template<class Params, uint32_t ItemsPerInvocation, bool native>
58
+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
56
59
struct exclusive_scan
57
60
{
58
61
using type_t = typename Params::type_t;
59
62
using scalar_t = typename Params::scalar_t;
60
63
using binop_t = typename Params::binop_t;
61
- using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, ItemsPerInvocation, native>;
64
+ using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
62
65
63
66
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
64
67
@@ -78,7 +81,7 @@ struct exclusive_scan
78
81
}
79
82
};
80
83
81
- template<class Params, uint32_t ItemsPerInvocation, bool native>
84
+ template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
82
85
struct reduction
83
86
{
84
87
using type_t = typename Params::type_t;
@@ -103,74 +106,98 @@ struct reduction
103
106
104
107
// specs for N=1 uses subgroup funcs
105
108
// specialize native
106
- // #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
107
- // { \
108
- // using type_t = T; \
109
- // \
110
- // type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
111
- // }
109
+ #define SPECIALIZE (NAME,BINOP,SUBGROUP_OP) template<class Params, typename T> struct NAME<Params, BINOP<T>, 1 ,true > \
110
+ { \
111
+ using type_t = T; \
112
+ \
113
+ type_t operator ()(NBL_CONST_REF_ARG (type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
114
+ }
112
115
113
- // #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
114
- // SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
115
- // SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116
+ #define SPECIALIZE_ALL (BINOP,SUBGROUP_OP) SPECIALIZE (reduction,BINOP,SUBGROUP_OP); \
117
+ SPECIALIZE (inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
118
+ SPECIALIZE (exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116
119
117
- // SPECIALIZE_ALL(bit_and,And);
118
- // SPECIALIZE_ALL(bit_or,Or);
119
- // SPECIALIZE_ALL(bit_xor,Xor);
120
+ SPECIALIZE_ALL (bit_and,And);
121
+ SPECIALIZE_ALL (bit_or,Or);
122
+ SPECIALIZE_ALL (bit_xor,Xor);
120
123
121
- // SPECIALIZE_ALL(plus,Add);
122
- // SPECIALIZE_ALL(multiplies,Mul);
124
+ SPECIALIZE_ALL (plus,Add );
125
+ SPECIALIZE_ALL (multiplies,Mul);
123
126
124
- // SPECIALIZE_ALL(minimum,Min);
125
- // SPECIALIZE_ALL(maximum,Max);
127
+ SPECIALIZE_ALL (minimum,Min );
128
+ SPECIALIZE_ALL (maximum,Max );
126
129
127
- // #undef SPECIALIZE_ALL
128
- // #undef SPECIALIZE
130
+ #undef SPECIALIZE_ALL
131
+ #undef SPECIALIZE
129
132
130
133
// specialize portability
131
- template<class Params, bool native >
132
- struct inclusive_scan<Params, 1 , native >
134
+ template<class Params, class BinOp >
135
+ struct inclusive_scan<Params, BinOp, 1 , false >
133
136
{
134
137
using type_t = typename Params::type_t;
135
138
using scalar_t = typename Params::scalar_t;
136
139
using binop_t = typename Params::binop_t;
137
- using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
138
140
// assert T == scalar type, binop::type == T
141
+ using config_t = typename Params::config_t;
139
142
140
- type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
143
+ // affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
144
+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
145
+
146
+ type_t operator ()(type_t value)
141
147
{
142
- op_t op;
143
- return op (value);
148
+ return __call (value);
149
+ }
150
+
151
+ static type_t __call (type_t value)
152
+ {
153
+ binop_t op;
154
+ const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID ();
155
+
156
+ type_t rhs = glsl::subgroupShuffleUp<type_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
157
+ // TODO waiting on mix intrinsic fix from bxdf branch, value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
158
+ value = op (value, subgroupInvocation<1u ? binop_t::identity : rhs);
159
+
160
+ const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
161
+ [unroll]
162
+ for (uint32_t i = 1 ; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
163
+ {
164
+ const uint32_t step = i * 2 ;
165
+ rhs = glsl::subgroupShuffleUp<type_t>(value, step);
166
+ // TODO value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
167
+ value = op (value, subgroupInvocation<step ? binop_t::identity : rhs);
168
+ }
169
+ return value;
144
170
}
145
171
};
146
172
147
- template<class Params, bool native >
148
- struct exclusive_scan<Params, 1 , native >
173
+ template<class Params, class BinOp >
174
+ struct exclusive_scan<Params, BinOp, 1 , false >
149
175
{
150
176
using type_t = typename Params::type_t;
151
177
using scalar_t = typename Params::scalar_t;
152
178
using binop_t = typename Params::binop_t;
153
- using op_t = subgroup::impl::exclusive_scan<binop_t, native>;
154
179
155
180
type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
156
181
{
157
- op_t op;
158
- return op (value);
182
+ value = inclusive_scan<Params, BinOp, 1 , false >::__call (value);
183
+ // can't risk getting short-circuited, need to store to a var
184
+ type_t left = glsl::subgroupShuffleUp<type_t>(value,1 );
185
+ // the first invocation doesn't have anything in its left so we set to the binop's identity value for exlusive scan
186
+ return bool (glsl::gl_SubgroupInvocationID ()) ? left:binop_t::identity;
159
187
}
160
188
};
161
189
162
- template<class Params, bool native >
163
- struct reduction<Params, 1 , native >
190
+ template<class Params, class BinOp >
191
+ struct reduction<Params, BinOp, 1 , false >
164
192
{
165
193
using type_t = typename Params::type_t;
166
194
using scalar_t = typename Params::scalar_t;
167
195
using binop_t = typename Params::binop_t;
168
- using op_t = subgroup::impl::reduction<binop_t, native>;
169
196
170
197
scalar_t operator ()(NBL_CONST_REF_ARG (type_t) value)
171
198
{
172
- op_t op;
173
- return op ( value);
199
+ // take the last subgroup invocation's value for the reduction
200
+ return subgroup::BroadcastLast<type_t>(inclusive_scan<Params, BinOp, 1 , false >:: __call ( value) );
174
201
}
175
202
};
176
203
0 commit comments