Skip to content

Commit a8e02a3

Browse files
committed
changes to Params, Config handling types
1 parent e88f51a commit a8e02a3

File tree

5 files changed

+117
-57
lines changed

5 files changed

+117
-57
lines changed

include/nbl/builtin/hlsl/subgroup/ballot.hlsl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,6 @@ uint32_t ElectedSubgroupInvocationID() {
3737
return glsl::subgroupBroadcastFirst<uint32_t>(glsl::gl_SubgroupInvocationID());
3838
}
3939

40-
template<uint32_t SubgroupSizeLog2>
41-
struct Configuration
42-
{
43-
using mask_t = conditional_t<SubgroupSizeLog2 < 7, conditional_t<SubgroupSizeLog2 < 6, uint32_t1, uint32_t2>, uint32_t4>;
44-
45-
NBL_CONSTEXPR_STATIC_INLINE uint16_t Size = 0x1u << SubgroupSizeLog2;
46-
};
47-
48-
template<class T>
49-
struct is_configuration : bool_constant<false> {};
50-
51-
template<uint32_t N>
52-
struct is_configuration<Configuration<N> > : bool_constant<true> {};
53-
54-
template<typename T>
55-
NBL_CONSTEXPR bool is_configuration_v = is_configuration<T>::value;
56-
5740
}
5841
}
5942
}

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
22
// This file is part of the "Nabla Engine".
33
// For conditions of distribution and use, see copyright notice in nabla.h
44
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
@@ -7,7 +7,7 @@
77

88
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
99

10-
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
10+
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1111
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
1212
#include "nbl/builtin/hlsl/concepts.hlsl"
1313

@@ -19,24 +19,26 @@ namespace hlsl
1919
namespace subgroup2
2020
{
2121

22-
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(subgroup::is_configuration_v<Config>)
22+
template<typename Config, class BinOp, int32_t _ItemsPerInvocation=1, class device_capabilities=void, bool OverrideUseNativeInstrinsics=true NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
2323
struct ArithmeticParams
2424
{
2525
using config_t = Config;
2626
using binop_t = BinOp;
2727
using scalar_t = typename BinOp::type_t; // BinOp should be with scalar type
28-
using type_t = conditional_t<_ItemsPerInvocation<2, scalar_t, vector<scalar_t, _ItemsPerInvocation> >;
28+
using type_t = vector<scalar_t, _ItemsPerInvocation>;// conditional_t<_ItemsPerInvocation<2, scalar_t, vector<scalar_t, _ItemsPerInvocation> >;
2929

3030
NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
31-
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
31+
// if OverrideUseNativeInstrinsics is true, tries to use native spirv intrinsics
32+
// if OverrideUseNativeInstrinsics is false, will always use emulated versions
33+
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic && OverrideUseNativeInstrinsics /*&& /*some heuristic for when its faster*/;
3234
};
3335

3436
template<typename Params>
35-
struct reduction : impl::reduction<typename Params::binop_t,typename Params::type_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
37+
struct reduction : impl::reduction<Params,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
3638
template<typename Params>
37-
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,typename Params::type_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
39+
struct inclusive_scan : impl::inclusive_scan<Params,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
3840
template<typename Params>
39-
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,typename Params::type_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
41+
struct exclusive_scan : impl::exclusive_scan<Params,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
4042

4143
}
4244
}

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
22
// This file is part of the "Nabla Engine".
33
// For conditions of distribution and use, see copyright notice in nabla.h
44
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
66

7+
// #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8+
// #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
9+
10+
// #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11+
12+
// #include "nbl/builtin/hlsl/functional.hlsl"
13+
714
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
815

916
namespace nbl
@@ -16,12 +23,12 @@ namespace subgroup2
1623
namespace impl
1724
{
1825

19-
template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
26+
template<class Params, uint32_t ItemsPerInvocation, bool native>
2027
struct inclusive_scan
2128
{
22-
using type_t = T;
23-
using scalar_t = typename Binop::type_t;
24-
using binop_t = Binop;
29+
using type_t = typename Params::type_t;
30+
using scalar_t = typename Params::scalar_t;
31+
using binop_t = typename Params::binop_t;
2532
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
2633

2734
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -31,27 +38,27 @@ struct inclusive_scan
3138
binop_t binop;
3239
type_t retval;
3340
retval[0] = value[0];
34-
//[unroll(ItemsPerInvocation-1)]
41+
[unroll]
3542
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
3643
retval[i] = binop(retval[i-1], value[i]);
3744

3845
exclusive_scan_op_t op;
3946
scalar_t exclusive = op(retval[ItemsPerInvocation-1]);
4047

41-
//[unroll(ItemsPerInvocation)]
48+
[unroll]
4249
for (uint32_t i = 0; i < ItemsPerInvocation; i++)
4350
retval[i] = binop(retval[i], exclusive);
4451
return retval;
4552
}
4653
};
4754

48-
template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
55+
template<class Params, uint32_t ItemsPerInvocation, bool native>
4956
struct exclusive_scan
5057
{
51-
using type_t = T;
52-
using scalar_t = typename Binop::type_t;
53-
using binop_t = Binop;
54-
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, ItemsPerInvocation, native>;
58+
using type_t = typename Params::type_t;
59+
using scalar_t = typename Params::scalar_t;
60+
using binop_t = typename Params::binop_t;
61+
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params, ItemsPerInvocation, native>;
5562

5663
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
5764

@@ -64,19 +71,19 @@ struct exclusive_scan
6471

6572
type_t retval;
6673
retval[0] = bool(glsl::gl_SubgroupInvocationID()) ? left[ItemsPerInvocation-1] : binop_t::identity;
67-
//[unroll(ItemsPerInvocation-1)]
74+
[unroll]
6875
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
6976
retval[i] = value[i-1];
7077
return retval;
7178
}
7279
};
7380

74-
template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
81+
template<class Params, uint32_t ItemsPerInvocation, bool native>
7582
struct reduction
7683
{
77-
using type_t = T; // TODO? assert scalar_type<T> == scalar_t
78-
using scalar_t = typename Binop::type_t;
79-
using binop_t = Binop;
84+
using type_t = typename Params::type_t;
85+
using scalar_t = typename Params::scalar_t;
86+
using binop_t = typename Params::binop_t;
8087
using op_t = subgroup::impl::reduction<binop_t, native>;
8188

8289
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -86,49 +93,81 @@ struct reduction
8693
binop_t binop;
8794
op_t op;
8895
scalar_t retval = value[0];
89-
//[unroll(ItemsPerInvocation-1)]
96+
[unroll]
9097
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
9198
retval = binop(retval, value[i]);
9299
return op(retval);
93100
}
94101
};
95102

96103

97-
// spec for N=1 uses subgroup funcs
98-
template<class Binop, typename T, bool native>
99-
struct inclusive_scan<Binop, T, 1, native>
104+
// specs for N=1 uses subgroup funcs
105+
// specialize native
106+
// #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
107+
// { \
108+
// using type_t = T; \
109+
// \
110+
// type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
111+
// }
112+
113+
// #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
114+
// SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
115+
// SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116+
117+
// SPECIALIZE_ALL(bit_and,And);
118+
// SPECIALIZE_ALL(bit_or,Or);
119+
// SPECIALIZE_ALL(bit_xor,Xor);
120+
121+
// SPECIALIZE_ALL(plus,Add);
122+
// SPECIALIZE_ALL(multiplies,Mul);
123+
124+
// SPECIALIZE_ALL(minimum,Min);
125+
// SPECIALIZE_ALL(maximum,Max);
126+
127+
// #undef SPECIALIZE_ALL
128+
// #undef SPECIALIZE
129+
130+
// specialize portability
131+
template<class Params, bool native>
132+
struct inclusive_scan<Params, 1, native>
100133
{
101-
using binop_t = Binop;
134+
using type_t = typename Params::type_t;
135+
using scalar_t = typename Params::scalar_t;
136+
using binop_t = typename Params::binop_t;
102137
using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
103138
// assert T == scalar type, binop::type == T
104139

105-
T operator()(NBL_CONST_REF_ARG(T) value)
140+
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
106141
{
107142
op_t op;
108143
return op(value);
109144
}
110145
};
111146

112-
template<class Binop, typename T, bool native>
113-
struct exclusive_scan<Binop, T, 1, native>
147+
template<class Params, bool native>
148+
struct exclusive_scan<Params, 1, native>
114149
{
115-
using binop_t = Binop;
150+
using type_t = typename Params::type_t;
151+
using scalar_t = typename Params::scalar_t;
152+
using binop_t = typename Params::binop_t;
116153
using op_t = subgroup::impl::exclusive_scan<binop_t, native>;
117154

118-
T operator()(NBL_CONST_REF_ARG(T) value)
155+
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
119156
{
120157
op_t op;
121158
return op(value);
122159
}
123160
};
124161

125-
template<class Binop, typename T, bool native>
126-
struct reduction<Binop, T, 1, native>
162+
template<class Params, bool native>
163+
struct reduction<Params, 1, native>
127164
{
128-
using binop_t = Binop;
165+
using type_t = typename Params::type_t;
166+
using scalar_t = typename Params::scalar_t;
167+
using binop_t = typename Params::binop_t;
129168
using op_t = subgroup::impl::reduction<binop_t, native>;
130169

131-
T operator()(NBL_CONST_REF_ARG(T) value)
170+
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
132171
{
133172
op_t op;
134173
return op(value);
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_SUBGROUP2_BALLOT_INCLUDED_
6+
7+
namespace nbl
8+
{
9+
namespace hlsl
10+
{
11+
namespace subgroup2
12+
{
13+
14+
template<uint32_t SubgroupSizeLog2>
15+
struct Configuration
16+
{
17+
using mask_t = conditional_t<SubgroupSizeLog2 < 7, conditional_t<SubgroupSizeLog2 < 6, uint32_t1, uint32_t2>, uint32_t4>;
18+
19+
NBL_CONSTEXPR_STATIC_INLINE uint16_t SizeLog2 = uint16_t(SubgroupSizeLog2);
20+
NBL_CONSTEXPR_STATIC_INLINE uint16_t Size = uint16_t(0x1u) << SubgroupSizeLog2;
21+
};
22+
23+
template<class T>
24+
struct is_configuration : bool_constant<false> {};
25+
26+
template<uint32_t N>
27+
struct is_configuration<Configuration<N> > : bool_constant<true> {};
28+
29+
template<typename T>
30+
NBL_CONSTEXPR bool is_configuration_v = is_configuration<T>::value;
31+
32+
}
33+
}
34+
}
35+
36+
#endif

0 commit comments

Comments
 (0)