@@ -18,7 +18,7 @@ struct and
18
18
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = ~0ull ; // this should be a reinterpret cast
19
19
20
20
// Combines the two operands with bitwise AND and returns the result.
inline T operator ()(T left, T right)
{
    const T combined = left & right;
    return combined;
}
21
-
21
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
22
22
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " and" ;
23
23
};
24
24
template <typename T>
@@ -28,7 +28,7 @@ struct xor
28
28
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = 0ull ; // this should be a reinterpret cast
29
29
30
30
// Combines the two operands with bitwise XOR and returns the result.
inline T operator ()(T left, T right)
{
    const T combined = left ^ right;
    return combined;
}
31
-
31
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
32
32
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " xor" ;
33
33
};
34
34
template <typename T>
@@ -38,7 +38,7 @@ struct or
38
38
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = 0ull ; // this should be a reinterpret cast
39
39
40
40
// Combines the two operands with bitwise OR and returns the result.
inline T operator ()(T left, T right)
{
    const T combined = left | right;
    return combined;
}
41
-
41
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
42
42
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " or" ;
43
43
};
44
44
template <typename T>
@@ -48,7 +48,7 @@ struct add
48
48
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(0 );
49
49
50
50
// Returns the sum of the two operands.
inline T operator ()(T left, T right)
{
    const T sum = left + right;
    return sum;
}
51
-
51
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
52
52
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " add" ;
53
53
};
54
54
template <typename T>
@@ -58,7 +58,7 @@ struct mul
58
58
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(1 );
59
59
60
60
// Returns the product of the two operands.
inline T operator ()(T left, T right)
{
    const T product = left * right;
    return product;
}
61
-
61
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
62
62
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " mul" ;
63
63
};
64
64
template <typename T>
@@ -68,7 +68,7 @@ struct min
68
68
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = std::numeric_limits<T>::max();
69
69
70
70
// Returns whichever operand std::min<T> selects as the smaller one.
inline T operator ()(T left, T right)
{
    const T& smaller = std::min<T>(left, right);
    return smaller;
}
71
-
71
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
72
72
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " min" ;
73
73
};
74
74
template <typename T>
@@ -78,22 +78,21 @@ struct max
78
78
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = std::numeric_limits<T>::lowest();
79
79
80
80
// Returns whichever operand std::max<T> selects as the larger one.
inline T operator ()(T left, T right)
{
    const T& larger = std::max<T>(left, right);
    return larger;
}
81
-
81
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false ;
82
82
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " max" ;
83
83
};
84
84
template <typename T>
85
- struct bitcount
85
+ struct countBits
86
86
{
87
87
using type_t = T;
88
88
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(0 );
89
89
90
- inline T operator ()(T left, T right) { return T ( 0 ); }
91
-
90
+ inline T operator ()(T left, T right) { return left + (right& 1u ); }
91
+ _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = true ;
92
92
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " bitcount" ;
93
93
};
94
94
95
95
96
-
97
96
// subgroup method emulations on the CPU, to verify the results of the GPU methods
98
97
template <class CRTP , typename T>
99
98
struct emulatedSubgroupCommon
@@ -122,7 +121,6 @@ struct emulatedSubgroupReduction : emulatedSubgroupCommon<emulatedSubgroupReduct
122
121
red = OP ()(red,subgroupData[i]);
123
122
std::fill (outSubgroupData,outSubgroupData+clampedSubgroupSize,red);
124
123
}
125
-
126
124
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " subgroup reduction" ;
127
125
};
128
126
template <class OP >
@@ -136,7 +134,6 @@ struct emulatedSubgroupScanExclusive : emulatedSubgroupCommon<emulatedSubgroupSc
136
134
for (auto i=1u ; i<clampedSubgroupSize; i++)
137
135
outSubgroupData[i] = OP ()(outSubgroupData[i-1u ],subgroupData[i-1u ]);
138
136
}
139
-
140
137
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " subgroup exclusive scan" ;
141
138
};
142
139
template <class OP >
@@ -150,7 +147,6 @@ struct emulatedSubgroupScanInclusive : emulatedSubgroupCommon<emulatedSubgroupSc
150
147
for (auto i=1u ; i<clampedSubgroupSize; i++)
151
148
outSubgroupData[i] = OP ()(outSubgroupData[i-1u ],subgroupData[i]);
152
149
}
153
-
154
150
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " subgroup inclusive scan" ;
155
151
};
156
152
@@ -162,12 +158,11 @@ struct emulatedWorkgroupReduction
162
158
163
159
// CPU emulation of a workgroup-wide reduction.
//
// outputData    : receives `workgroupSize` copies of the reduced value
// workgroupData : the `workgroupSize` input values of one workgroup
// workgroupSize : number of elements read and written
// subgroupSize  : not used by this emulation
inline void operator ()(type_t * outputData, const type_t * workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
{
    // Seed with the first element; operators flagged with runOPonFirst need
    // that element passed through the operator instead of taken verbatim.
    type_t accumulated = workgroupData[0];
    if (OP::runOPonFirst)
        accumulated = OP()(0, accumulated);

    // Fold the remaining elements in order.
    uint32_t idx = 1u;
    while (idx < workgroupSize)
    {
        accumulated = OP()(accumulated, workgroupData[idx]);
        ++idx;
    }

    // Every invocation of the workgroup sees the same reduced value.
    std::fill_n(outputData, workgroupSize, accumulated);
}
170
-
171
166
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " workgroup reduction" ;
172
167
};
173
168
template <class OP >
@@ -177,11 +172,10 @@ struct emulatedWorkgroupScanExclusive
177
172
178
173
// CPU emulation of a workgroup-wide exclusive scan.
//
// outputData[i] receives OP folded over workgroupData[0] .. workgroupData[i-1],
// starting from OP::IdentityElement, so outputData[0] is the identity itself.
//
// outputData    : receives `workgroupSize` scan results
// workgroupData : the `workgroupSize` input values of one workgroup
// workgroupSize : number of elements read and written; nothing is touched when 0
// subgroupSize  : not used by this emulation
inline void operator ()(type_t * outputData, const type_t * workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
{
    // Nothing to scan; also avoids writing outputData[0] of an empty range.
    if (workgroupSize == 0u)
        return;

    // The first result is always the identity, even for operators that set
    // runOPonFirst: element 0 is folded in by the loop below when it produces
    // outputData[1], so applying it here as well would count it twice.
    outputData[0u] = OP::IdentityElement;
    for (auto i = 1u; i < workgroupSize; i++)
        outputData[i] = OP()(outputData[i - 1u], workgroupData[i - 1u]);
}
184
-
185
179
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " workgroup exclusive scan" ;
186
180
};
187
181
template <class OP >
@@ -191,11 +185,10 @@ struct emulatedWorkgroupScanInclusive
191
185
192
186
// CPU emulation of a workgroup-wide inclusive scan.
//
// outputData[i] receives OP folded over workgroupData[0] .. workgroupData[i].
//
// outputData    : receives `workgroupSize` scan results
// workgroupData : the `workgroupSize` input values of one workgroup
// workgroupSize : number of elements read and written
// subgroupSize  : not used by this emulation
inline void operator ()(type_t * outputData, const type_t * workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
{
    // First result: the first element, passed through the operator when the
    // operator is flagged with runOPonFirst, otherwise copied verbatim.
    if (OP::runOPonFirst)
        outputData[0u] = OP()(0, workgroupData[0]);
    else
        outputData[0u] = workgroupData[0u];

    // Each later result extends the previous one by the current element.
    uint32_t idx = 1u;
    while (idx < workgroupSize)
    {
        outputData[idx] = OP()(outputData[idx - 1u], workgroupData[idx]);
        ++idx;
    }
}
198
-
199
192
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " workgroup inclusive scan" ;
200
193
};
201
194
@@ -246,7 +239,7 @@ bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, con
246
239
for (uint32_t localInvocationIndex=0u ; localInvocationIndex<workgroupSize; localInvocationIndex++)
247
240
if (tmp[localInvocationIndex]!=dataFromBuffer[workgroupOffset+localInvocationIndex])
248
241
{
249
- os::Printer::log (" Failed test #" + std::to_string (workgroupSize) + " (" + Arithmetic<OP<uint32_t >>::name + " ) (" + OP<uint32_t >::name + " )" , ELL_ERROR);
242
+ os::Printer::log (" Failed test #" + std::to_string (workgroupSize) + " (" + Arithmetic<OP<uint32_t >>::name + " ) (" + OP<uint32_t >::name + " ) Expected " + std::to_string (dataFromBuffer[workgroupOffset + localInvocationIndex])+ " got " + std::to_string (tmp[localInvocationIndex]) , ELL_ERROR);
250
243
success = false ;
251
244
break ;
252
245
}
@@ -277,7 +270,9 @@ bool runTest(video::IVideoDriver* driver, video::IGPUComputePipeline* pipeline,
277
270
passed = validateResults<Arithmetic,::min>(driver, inputData, workgroupSize, workgroupCount, buffers[5 ].get ())&&passed;
278
271
passed = validateResults<Arithmetic,::max>(driver, inputData, workgroupSize, workgroupCount, buffers[6 ].get ())&&passed;
279
272
if (is_workgroup_test)
280
- passed = validateResults<Arithmetic,bitcount>(driver, inputData, workgroupSize, workgroupCount, buffers[7 ].get ()) && passed;
273
+ {
274
+ passed = validateResults<Arithmetic, countBits>(driver, inputData, workgroupSize, workgroupCount, buffers[7 ].get ()) && passed;
275
+ }
281
276
282
277
return passed;
283
278
}
0 commit comments