Skip to content

Commit d4ca075

Browse files
committed
fixes to subgroup2 funcs
1 parent 9401bf3 commit d4ca075

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ struct ArithmeticParams
2828
using type_t = vector<scalar_t, _ItemsPerInvocation>;// conditional_t<_ItemsPerInvocation<2, scalar_t, vector<scalar_t, _ItemsPerInvocation> >;
2929

3030
NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
31-
// if OverrideUseNativeInstrinsics is true, tries to use native spirv intrinsics
32-
// if OverrideUseNativeInstrinsics is false, will always use emulated versions
3331
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
3432
};
3533

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ namespace subgroup2
2323
namespace impl
2424
{
2525

26+
// forward declarations
27+
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
28+
struct inclusive_scan;
29+
30+
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
31+
struct exclusive_scan;
32+
33+
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
34+
struct reduction;
35+
36+
2637
// BinOp needed to specialize native
2738
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
2839
struct inclusive_scan
@@ -31,7 +42,7 @@ struct inclusive_scan
3142
using scalar_t = typename Params::scalar_t;
3243
using binop_t = typename Params::binop_t;
3344
// assert binop_t == BinOp
34-
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
45+
using exclusive_scan_op_t = exclusive_scan<Params, binop_t, 1, native>;
3546

3647
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
3748

@@ -43,7 +54,7 @@ struct inclusive_scan
4354
[unroll]
4455
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
4556
retval[i] = binop(retval[i-1], value[i]);
46-
57+
4758
exclusive_scan_op_t op;
4859
scalar_t exclusive = op(retval[ItemsPerInvocation-1]);
4960

@@ -60,7 +71,7 @@ struct exclusive_scan
6071
using type_t = typename Params::type_t;
6172
using scalar_t = typename Params::scalar_t;
6273
using binop_t = typename Params::binop_t;
63-
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
74+
using inclusive_scan_op_t = inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
6475

6576
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
6677

@@ -86,7 +97,7 @@ struct reduction
8697
using type_t = typename Params::type_t;
8798
using scalar_t = typename Params::scalar_t;
8899
using binop_t = typename Params::binop_t;
89-
using op_t = subgroup::impl::reduction<binop_t, native>;
100+
using op_t = reduction<Params, binop_t, 1, native>;
90101

91102
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
92103

@@ -142,25 +153,25 @@ struct inclusive_scan<Params, BinOp, 1, false>
142153
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
143154
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
144155

145-
type_t operator()(type_t value)
156+
scalar_t operator()(scalar_t value)
146157
{
147158
return __call(value);
148159
}
149160

150-
static type_t __call(type_t value)
161+
static scalar_t __call(scalar_t value)
151162
{
152163
binop_t op;
153164
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
154-
155-
type_t rhs = glsl::subgroupShuffleUp<type_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
165+
166+
scalar_t rhs = glsl::subgroupShuffleUp<scalar_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
156167
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
157-
168+
158169
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
159170
[unroll]
160171
for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
161172
{
162-
const uint32_t step = i * 2;
163-
rhs = glsl::subgroupShuffleUp<type_t>(value, step);
173+
const uint32_t step = 1u << i;
174+
rhs = glsl::subgroupShuffleUp<scalar_t>(value, step);
164175
value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
165176
}
166177
return value;
@@ -174,13 +185,13 @@ struct exclusive_scan<Params, BinOp, 1, false>
174185
using scalar_t = typename Params::scalar_t;
175186
using binop_t = typename Params::binop_t;
176187

177-
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
188+
scalar_t operator()(scalar_t value)
178189
{
179190
value = inclusive_scan<Params, BinOp, 1, false>::__call(value);
180191
// can't risk getting short-circuited, need to store to a var
181-
type_t left = glsl::subgroupShuffleUp<type_t>(value,1);
192+
scalar_t left = glsl::subgroupShuffleUp<scalar_t>(value,1);
182193
// the first invocation doesn't have anything in its left so we set to the binop's identity value for exlusive scan
183-
return hlsl::mix(binop_t::identity, left, bool(glsl::gl_SubgroupInvocationID()));
194+
return bool(glsl::gl_SubgroupInvocationID()) ? left:binop_t::identity;
184195
}
185196
};
186197

@@ -190,11 +201,21 @@ struct reduction<Params, BinOp, 1, false>
190201
using type_t = typename Params::type_t;
191202
using scalar_t = typename Params::scalar_t;
192203
using binop_t = typename Params::binop_t;
204+
using config_t = typename Params::config_t;
193205

194-
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
206+
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
207+
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
208+
209+
scalar_t operator()(scalar_t value)
195210
{
196-
// take the last subgroup invocation's value for the reduction
197-
return subgroup::BroadcastLast<type_t>(inclusive_scan<Params, BinOp, 1, false>::__call(value));
211+
binop_t op;
212+
213+
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
214+
[unroll]
215+
for (uint32_t i = 0; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
216+
value = op(glsl::subgroupShuffleXor<scalar_t>(value,0x1u<<i),value);
217+
218+
return value;
198219
}
199220
};
200221

0 commit comments

Comments
 (0)