Skip to content

Commit 237ac09

Browse files
committed
rework specializations for native, emulated funcs
1 parent a8e02a3 commit 237ac09

File tree

2 files changed

+75
-48
lines changed

2 files changed

+75
-48
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace hlsl
1919
namespace subgroup2
2020
{
2121

22-
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void, bool OverrideUseNativeInstrinsics=true NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
22+
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
2323
struct ArithmeticParams
2424
{
2525
using config_t = Config;
@@ -30,15 +30,15 @@ struct ArithmeticParams
3030
NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
3131
// if OverrideUseNativeInstrinsics is true, tries to use native spirv intrinsics
3232
// if OverrideUseNativeInstrinsics is false, will always use emulated versions
33-
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic && OverrideUseNativeInstrinsics /*&& /*some heuristic for when its faster*/;
33+
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
3434
};
3535

3636
template<typename Params>
37-
struct reduction : impl::reduction<Params,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
37+
struct reduction : impl::reduction<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
3838
template<typename Params>
39-
struct inclusive_scan : impl::inclusive_scan<Params,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
39+
struct inclusive_scan : impl::inclusive_scan<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
4040
template<typename Params>
41-
struct exclusive_scan : impl::exclusive_scan<Params,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
41+
struct exclusive_scan : impl::exclusive_scan<Params,typename Params::binop_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
4242

4343
}
4444
}

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 70 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
66

7-
// #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8-
// #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
7+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8+
#include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
99

10-
// #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
10+
#include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11+
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1112

12-
// #include "nbl/builtin/hlsl/functional.hlsl"
13+
#include "nbl/builtin/hlsl/functional.hlsl"
1314

14-
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
15+
// #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
1516

1617
namespace nbl
1718
{
@@ -23,12 +24,14 @@ namespace subgroup2
2324
namespace impl
2425
{
2526

26-
template<class Params, uint32_t ItemsPerInvocation, bool native>
27+
// BinOp needed to specialize native
28+
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
2729
struct inclusive_scan
2830
{
2931
using type_t = typename Params::type_t;
3032
using scalar_t = typename Params::scalar_t;
3133
using binop_t = typename Params::binop_t;
34+
// assert binop_t == BinOp
3235
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
3336

3437
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -52,13 +55,13 @@ struct inclusive_scan
5255
}
5356
};
5457

55-
template<class Params, uint32_t ItemsPerInvocation, bool native>
58+
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
5659
struct exclusive_scan
5760
{
5861
using type_t = typename Params::type_t;
5962
using scalar_t = typename Params::scalar_t;
6063
using binop_t = typename Params::binop_t;
61-
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, ItemsPerInvocation, native>;
64+
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, binop_t, ItemsPerInvocation, native>;
6265

6366
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
6467

@@ -78,7 +81,7 @@ struct exclusive_scan
7881
}
7982
};
8083

81-
template<class Params, uint32_t ItemsPerInvocation, bool native>
84+
template<class Params, class BinOp, uint32_t ItemsPerInvocation, bool native>
8285
struct reduction
8386
{
8487
using type_t = typename Params::type_t;
@@ -103,74 +106,98 @@ struct reduction
103106

104107
// specs for N=1 uses subgroup funcs
105108
// specialize native
106-
// #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
107-
// { \
108-
// using type_t = T; \
109-
// \
110-
// type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
111-
// }
109+
#define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<class Params, typename T> struct NAME<Params,BINOP<T>,1,true> \
110+
{ \
111+
using type_t = T; \
112+
\
113+
type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
114+
}
112115

113-
// #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
114-
// SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
115-
// SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116+
#define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
117+
SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
118+
SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116119

117-
// SPECIALIZE_ALL(bit_and,And);
118-
// SPECIALIZE_ALL(bit_or,Or);
119-
// SPECIALIZE_ALL(bit_xor,Xor);
120+
SPECIALIZE_ALL(bit_and,And);
121+
SPECIALIZE_ALL(bit_or,Or);
122+
SPECIALIZE_ALL(bit_xor,Xor);
120123

121-
// SPECIALIZE_ALL(plus,Add);
122-
// SPECIALIZE_ALL(multiplies,Mul);
124+
SPECIALIZE_ALL(plus,Add);
125+
SPECIALIZE_ALL(multiplies,Mul);
123126

124-
// SPECIALIZE_ALL(minimum,Min);
125-
// SPECIALIZE_ALL(maximum,Max);
127+
SPECIALIZE_ALL(minimum,Min);
128+
SPECIALIZE_ALL(maximum,Max);
126129

127-
// #undef SPECIALIZE_ALL
128-
// #undef SPECIALIZE
130+
#undef SPECIALIZE_ALL
131+
#undef SPECIALIZE
129132

130133
// specialize portability
131-
template<class Params, bool native>
132-
struct inclusive_scan<Params, 1, native>
134+
template<class Params, class BinOp>
135+
struct inclusive_scan<Params, BinOp, 1, false>
133136
{
134137
using type_t = typename Params::type_t;
135138
using scalar_t = typename Params::scalar_t;
136139
using binop_t = typename Params::binop_t;
137-
using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
138140
// assert T == scalar type, binop::type == T
141+
using config_t = typename Params::config_t;
139142

140-
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
143+
// affected by https://github.com/microsoft/DirectXShaderCompiler/issues/7006
144+
// NBL_CONSTEXPR_STATIC_INLINE uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
145+
146+
type_t operator()(type_t value)
141147
{
142-
op_t op;
143-
return op(value);
148+
return __call(value);
149+
}
150+
151+
static type_t __call(type_t value)
152+
{
153+
binop_t op;
154+
const uint32_t subgroupInvocation = glsl::gl_SubgroupInvocationID();
155+
156+
type_t rhs = glsl::subgroupShuffleUp<type_t>(value, 1u); // all invocations must execute the shuffle, even if we don't apply the op() to all of them
157+
// TODO waiting on mix intrinsic fix from bxdf branch, value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < 1u));
158+
value = op(value, subgroupInvocation<1u ? binop_t::identity : rhs);
159+
160+
const uint32_t SubgroupSizeLog2 = config_t::SizeLog2;
161+
[unroll]
162+
for (uint32_t i = 1; i < integral_constant<uint32_t,SubgroupSizeLog2>::value; i++)
163+
{
164+
const uint32_t step = i * 2;
165+
rhs = glsl::subgroupShuffleUp<type_t>(value, step);
166+
// TODO value = op(value, hlsl::mix(rhs, binop_t::identity, subgroupInvocation < step));
167+
value = op(value, subgroupInvocation<step ? binop_t::identity : rhs);
168+
}
169+
return value;
144170
}
145171
};
146172

147-
template<class Params, bool native>
148-
struct exclusive_scan<Params, 1, native>
173+
template<class Params, class BinOp>
174+
struct exclusive_scan<Params, BinOp, 1, false>
149175
{
150176
using type_t = typename Params::type_t;
151177
using scalar_t = typename Params::scalar_t;
152178
using binop_t = typename Params::binop_t;
153-
using op_t = subgroup::impl::exclusive_scan<binop_t, native>;
154179

155180
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
156181
{
157-
op_t op;
158-
return op(value);
182+
value = inclusive_scan<Params, BinOp, 1, false>::__call(value);
183+
// can't risk getting short-circuited, need to store to a var
184+
type_t left = glsl::subgroupShuffleUp<type_t>(value,1);
185+
// the first invocation doesn't have anything in its left so we set to the binop's identity value for exlusive scan
186+
return bool(glsl::gl_SubgroupInvocationID()) ? left:binop_t::identity;
159187
}
160188
};
161189

162-
template<class Params, bool native>
163-
struct reduction<Params, 1, native>
190+
template<class Params, class BinOp>
191+
struct reduction<Params, BinOp, 1, false>
164192
{
165193
using type_t = typename Params::type_t;
166194
using scalar_t = typename Params::scalar_t;
167195
using binop_t = typename Params::binop_t;
168-
using op_t = subgroup::impl::reduction<binop_t, native>;
169196

170197
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
171198
{
172-
op_t op;
173-
return op(value);
199+
// take the last subgroup invocation's value for the reduction
200+
return subgroup::BroadcastLast<type_t>(inclusive_scan<Params, BinOp, 1, false>::__call(value));
174201
}
175202
};
176203

0 commit comments

Comments
 (0)