1
- // Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
1
+ // Copyright (C) 2025 - DevSH Graphics Programming Sp. z O.O.
2
2
// This file is part of the "Nabla Engine".
3
3
// For conditions of distribution and use, see copyright notice in nabla.h
4
4
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
5
5
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
6
6
7
+ // #include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
8
+ // #include "nbl/builtin/hlsl/glsl_compat/subgroup_arithmetic.hlsl"
9
+
10
+ // #include "nbl/builtin/hlsl/subgroup/ballot.hlsl"
11
+
12
+ // #include "nbl/builtin/hlsl/functional.hlsl"
13
+
7
14
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
8
15
9
16
namespace nbl
@@ -16,12 +23,12 @@ namespace subgroup2
16
23
namespace impl
17
24
{
18
25
19
- template<class Binop, typename T , uint32_t ItemsPerInvocation, bool native>
26
+ template<class Params , uint32_t ItemsPerInvocation, bool native>
20
27
struct inclusive_scan
21
28
{
22
- using type_t = T ;
23
- using scalar_t = typename Binop::type_t ;
24
- using binop_t = Binop ;
29
+ using type_t = typename Params::type_t ;
30
+ using scalar_t = typename Params::scalar_t ;
31
+ using binop_t = typename Params::binop_t ;
25
32
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
26
33
27
34
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -31,27 +38,27 @@ struct inclusive_scan
31
38
binop_t binop;
32
39
type_t retval;
33
40
retval[0 ] = value[0 ];
34
- // [unroll(ItemsPerInvocation-1) ]
41
+ [unroll]
35
42
for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
36
43
retval[i] = binop (retval[i-1 ], value[i]);
37
44
38
45
exclusive_scan_op_t op;
39
46
scalar_t exclusive = op (retval[ItemsPerInvocation-1 ]);
40
47
41
- // [unroll(ItemsPerInvocation) ]
48
+ [unroll]
42
49
for (uint32_t i = 0 ; i < ItemsPerInvocation; i++)
43
50
retval[i] = binop (retval[i], exclusive);
44
51
return retval;
45
52
}
46
53
};
47
54
48
- template<class Binop, typename T , uint32_t ItemsPerInvocation, bool native>
55
+ template<class Params , uint32_t ItemsPerInvocation, bool native>
49
56
struct exclusive_scan
50
57
{
51
- using type_t = T ;
52
- using scalar_t = typename Binop::type_t ;
53
- using binop_t = Binop ;
54
- using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T , ItemsPerInvocation, native>;
58
+ using type_t = typename Params::type_t ;
59
+ using scalar_t = typename Params::scalar_t ;
60
+ using binop_t = typename Params::binop_t ;
61
+ using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Params , ItemsPerInvocation, native>;
55
62
56
63
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
57
64
@@ -64,19 +71,19 @@ struct exclusive_scan
64
71
65
72
type_t retval;
66
73
retval[0 ] = bool (glsl::gl_SubgroupInvocationID ()) ? left[ItemsPerInvocation-1 ] : binop_t::identity;
67
- // [unroll(ItemsPerInvocation-1) ]
74
+ [unroll]
68
75
for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
69
76
retval[i] = value[i-1 ];
70
77
return retval;
71
78
}
72
79
};
73
80
74
- template<class Binop, typename T , uint32_t ItemsPerInvocation, bool native>
81
+ template<class Params , uint32_t ItemsPerInvocation, bool native>
75
82
struct reduction
76
83
{
77
- using type_t = T; // TODO? assert scalar_type<T> == scalar_t
78
- using scalar_t = typename Binop::type_t ;
79
- using binop_t = Binop ;
84
+ using type_t = typename Params::type_t;
85
+ using scalar_t = typename Params::scalar_t ;
86
+ using binop_t = typename Params::binop_t ;
80
87
using op_t = subgroup::impl::reduction<binop_t, native>;
81
88
82
89
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
@@ -86,49 +93,81 @@ struct reduction
86
93
binop_t binop;
87
94
op_t op;
88
95
scalar_t retval = value[0 ];
89
- // [unroll(ItemsPerInvocation-1) ]
96
+ [unroll]
90
97
for (uint32_t i = 1 ; i < ItemsPerInvocation; i++)
91
98
retval = binop (retval, value[i]);
92
99
return op (retval);
93
100
}
94
101
};
95
102
96
103
97
- // spec for N=1 uses subgroup funcs
98
- template<class Binop, typename T, bool native>
99
- struct inclusive_scan<Binop, T, 1 , native>
104
+ // specs for N=1 uses subgroup funcs
105
+ // specialize native
106
+ // #define SPECIALIZE(NAME,BINOP,SUBGROUP_OP) template<typename T> struct NAME<BINOP<T>,true> \
107
+ // { \
108
+ // using type_t = T; \
109
+ // \
110
+ // type_t operator()(NBL_CONST_REF_ARG(type_t) v) {return glsl::subgroup##SUBGROUP_OP<type_t>(v);} \
111
+ // }
112
+
113
+ // #define SPECIALIZE_ALL(BINOP,SUBGROUP_OP) SPECIALIZE(reduction,BINOP,SUBGROUP_OP); \
114
+ // SPECIALIZE(inclusive_scan,BINOP,Inclusive##SUBGROUP_OP); \
115
+ // SPECIALIZE(exclusive_scan,BINOP,Exclusive##SUBGROUP_OP);
116
+
117
+ // SPECIALIZE_ALL(bit_and,And);
118
+ // SPECIALIZE_ALL(bit_or,Or);
119
+ // SPECIALIZE_ALL(bit_xor,Xor);
120
+
121
+ // SPECIALIZE_ALL(plus,Add);
122
+ // SPECIALIZE_ALL(multiplies,Mul);
123
+
124
+ // SPECIALIZE_ALL(minimum,Min);
125
+ // SPECIALIZE_ALL(maximum,Max);
126
+
127
+ // #undef SPECIALIZE_ALL
128
+ // #undef SPECIALIZE
129
+
130
+ // specialize portability
131
+ template<class Params, bool native>
132
+ struct inclusive_scan<Params, 1 , native>
100
133
{
101
- using binop_t = Binop;
134
+ using type_t = typename Params::type_t;
135
+ using scalar_t = typename Params::scalar_t;
136
+ using binop_t = typename Params::binop_t;
102
137
using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
103
138
// assert T == scalar type, binop::type == T
104
139
105
- T operator ()(NBL_CONST_REF_ARG (T ) value)
140
+ type_t operator ()(NBL_CONST_REF_ARG (type_t ) value)
106
141
{
107
142
op_t op;
108
143
return op (value);
109
144
}
110
145
};
111
146
112
- template<class Binop, typename T , bool native>
113
- struct exclusive_scan<Binop, T , 1 , native>
147
+ template<class Params , bool native>
148
+ struct exclusive_scan<Params , 1 , native>
114
149
{
115
- using binop_t = Binop;
150
+ using type_t = typename Params::type_t;
151
+ using scalar_t = typename Params::scalar_t;
152
+ using binop_t = typename Params::binop_t;
116
153
using op_t = subgroup::impl::exclusive_scan<binop_t, native>;
117
154
118
- T operator ()(NBL_CONST_REF_ARG (T ) value)
155
+ type_t operator ()(NBL_CONST_REF_ARG (type_t ) value)
119
156
{
120
157
op_t op;
121
158
return op (value);
122
159
}
123
160
};
124
161
125
- template<class Binop, typename T , bool native>
126
- struct reduction<Binop, T , 1 , native>
162
+ template<class Params , bool native>
163
+ struct reduction<Params , 1 , native>
127
164
{
128
- using binop_t = Binop;
165
+ using type_t = typename Params::type_t;
166
+ using scalar_t = typename Params::scalar_t;
167
+ using binop_t = typename Params::binop_t;
129
168
using op_t = subgroup::impl::reduction<binop_t, native>;
130
169
131
- T operator ()(NBL_CONST_REF_ARG (T ) value)
170
+ scalar_t operator ()(NBL_CONST_REF_ARG (type_t ) value)
132
171
{
133
172
op_t op;
134
173
return op (value);
0 commit comments