4
4
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
5
5
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
6
6
7
- #include "nbl/builtin/hlsl/subgroup/arithmetic_portability .hlsl"
7
+ #include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl .hlsl"
8
8
9
9
namespace nbl
10
10
{
@@ -16,15 +16,15 @@ namespace subgroup2
16
16
namespace impl
17
17
{
18
18
19
- template<class Binop, typename T, bool native>
19
+ template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
20
20
struct inclusive_scan
21
21
{
22
22
using type_t = T;
23
23
using scalar_t = typename Binop::type_t;
24
24
using binop_t = Binop;
25
25
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
26
26
27
- NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
27
+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
28
28
29
29
type_t operator ()(NBL_CONST_REF_ARG (type_t) value)
30
30
{
@@ -45,15 +45,15 @@ struct inclusive_scan
45
45
}
46
46
};
47
47
48
- template<class Binop, typename T, bool native>
48
+ template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
49
49
struct exclusive_scan
50
50
{
51
51
using type_t = T;
52
52
using scalar_t = typename Binop::type_t;
53
53
using binop_t = Binop;
54
- using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, native>;
54
+ using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, ItemsPerInvocation, native>;
55
55
56
- NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
56
+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
57
57
58
58
type_t operator ()(type_t value)
59
59
{
@@ -71,15 +71,15 @@ struct exclusive_scan
71
71
}
72
72
};
73
73
74
- template<class Binop, typename T, bool native>
74
+ template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
75
75
struct reduction
76
76
{
77
77
using type_t = T; // TODO? assert scalar_type<T> == scalar_t
78
78
using scalar_t = typename Binop::type_t;
79
79
using binop_t = Binop;
80
80
using op_t = subgroup::impl::reduction<binop_t, native>;
81
81
82
- NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
82
+ // NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
83
83
84
84
scalar_t operator ()(NBL_CONST_REF_ARG (type_t) value)
85
85
{
@@ -93,6 +93,48 @@ struct reduction
93
93
}
94
94
};
95
95
96
+
97
+ // spec for N=1 uses subgroup funcs
98
// Partial specialization of inclusive_scan for ItemsPerInvocation == 1.
// With a single item per invocation the work is forwarded unchanged to
// subgroup::impl::inclusive_scan<Binop, native>.
template<class Binop, typename T, bool native>
struct inclusive_scan<Binop, T, 1, native>
{
    // type_t / scalar_t mirror the aliases of the primary template so that code
    // written against inclusive_scan<...>::type_t keeps compiling for N == 1.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::inclusive_scan<binop_t, native>;
    // TODO: assert T == scalar type, binop::type == T

    // Runs the wrapped subgroup inclusive scan on `value` and returns its result.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
111
+
112
// Partial specialization of exclusive_scan for ItemsPerInvocation == 1.
// With a single item per invocation the work is forwarded unchanged to
// subgroup::impl::exclusive_scan<Binop, native>.
template<class Binop, typename T, bool native>
struct exclusive_scan<Binop, T, 1, native>
{
    // type_t / scalar_t mirror the aliases of the primary template so that code
    // written against exclusive_scan<...>::type_t keeps compiling for N == 1.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::exclusive_scan<binop_t, native>;

    // Runs the wrapped subgroup exclusive scan on `value` and returns its result.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
124
+
125
// Partial specialization of reduction for ItemsPerInvocation == 1.
// With a single item per invocation the work is forwarded unchanged to
// subgroup::impl::reduction<Binop, native>.
template<class Binop, typename T, bool native>
struct reduction<Binop, T, 1, native>
{
    // type_t / scalar_t mirror the aliases of the primary template so that code
    // written against reduction<...>::scalar_t keeps compiling for N == 1.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::reduction<binop_t, native>;

    // Runs the wrapped subgroup reduction on `value` and returns its result.
    // NOTE(review): the primary template returns scalar_t; this specialization
    // returns T, which presumably is the same type when N == 1 - confirm.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
137
+
96
138
}
97
139
98
140
}
0 commit comments