@@ -82,15 +82,7 @@ struct max
82
82
_NBL_STATIC_INLINE_CONSTEXPR const char * name = " max" ;
83
83
};
84
84
// Operation tag selected for the ballot test. It declares no members of its own and
// inherits everything from add<T>.
// The removed countBits functor shown below masked its right operand inline
// (left + (right&1u)); with ballot, validateResults applies the `&0x1u` mask to the
// input before the emulated operation runs.
template <typename T>
85
- struct countBits
86
- {
87
- using type_t = T;
88
- _NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(0 );
89
-
90
- inline T operator ()(T left, T right) { return left + (right&1u ); }
91
- _NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = true ;
92
- _NBL_STATIC_INLINE_CONSTEXPR const char * name = " bitcount" ;
93
- };
85
+ struct ballot : add<T> {};
94
86
95
87
96
88
// subgroup method emulations on the CPU, to verify the results of the GPU methods
@@ -172,7 +164,7 @@ struct emulatedWorkgroupScanExclusive
172
164
173
165
// CPU emulation of an exclusive scan over one workgroup.
//   outputData    - receives workgroupSize results
//   workgroupData - the workgroupSize input values
//   workgroupSize - number of elements to process; element 0 is written unconditionally,
//                   so callers need workgroupSize >= 1
//   subgroupSize  - not used by the visible body
// Result: outputData[0] is OP::IdentityElement, and for i >= 1
// outputData[i] = OP()(outputData[i-1], workgroupData[i-1]).
inline void operator ()(type_t * outputData, const type_t * workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
174
166
{
175
// Removed line: could apply OP to the first input when OP::runOPonFirst was set.
- outputData[0u ] = OP::runOPonFirst ? OP ()( 0 , workgroupData[ 0 ]) : OP:: IdentityElement;
167
// Added line: the first output is always the identity, independent of the input.
+ outputData[0u ] = OP::IdentityElement;
176
168
for (auto i=1u ; i<workgroupSize; i++)
177
169
// each output combines the previous output with the previous input (input i itself is excluded)
outputData[i] = OP ()(outputData[i-1u ],workgroupData[i-1u ]);
178
170
}
@@ -185,7 +177,7 @@ struct emulatedWorkgroupScanInclusive
185
177
186
178
// CPU emulation of an inclusive scan over one workgroup.
//   outputData    - receives workgroupSize results
//   workgroupData - the workgroupSize input values
//   workgroupSize - number of elements to process; element 0 is read and written
//                   unconditionally, so callers need workgroupSize >= 1
//   subgroupSize  - not used by the visible body
// Result: outputData[0] is workgroupData[0], and for i >= 1
// outputData[i] = OP()(outputData[i-1], workgroupData[i]).
inline void operator ()(type_t * outputData, const type_t * workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
187
179
{
188
// Removed line: could apply OP to (0, first input) when OP::runOPonFirst was set.
- outputData[0u ] = OP::runOPonFirst ? OP ()( 0 , workgroupData[ 0 ]) : workgroupData[0u ];
180
// Added line: the first output is always a plain copy of the first input.
+ outputData[0u ] = workgroupData[0u ];
189
181
for (auto i=1u ; i<workgroupSize; i++)
190
182
// each output combines the previous output with the current input (input i is included)
outputData[i] = OP ()(outputData[i-1u ],workgroupData[i]);
191
183
}
@@ -196,6 +188,7 @@ struct emulatedWorkgroupScanInclusive
196
188
#include " common.glsl"
197
189
constexpr uint32_t kBufferSize = BUFFER_DWORD_COUNT*sizeof (uint32_t );
198
190
191
+
199
192
// returns true if result matches
200
193
template <template <class > class Arithmetic , template <class > class OP >
201
194
bool validateResults (video::IVideoDriver* driver, const uint32_t * inputData, const uint32_t workgroupSize, const uint32_t workgroupCount, video::IGPUBuffer* bufferToDownload)
@@ -232,18 +225,27 @@ bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, con
232
225
// now check if the data obtained has valid values
233
226
constexpr uint32_t subgroupSize = 4u ;
234
227
uint32_t * tmp = new uint32_t [workgroupSize];
228
+ uint32_t * ballotInput = new uint32_t [workgroupSize];
235
229
for (uint32_t workgroupID=0u ; success&&workgroupID<workgroupCount; workgroupID++)
236
230
{
237
231
const auto workgroupOffset = workgroupID*workgroupSize;
238
- Arithmetic<OP<uint32_t >>()(tmp,inputData+workgroupOffset,workgroupSize,subgroupSize);
232
+ if constexpr (std::is_same_v<OP<uint32_t >,ballot<uint32_t >>)
233
+ {
234
+ for (auto i=0u ; i<workgroupSize; i++)
235
+ ballotInput[i] = inputData[i+workgroupOffset]&0x1u ;
236
+ Arithmetic<OP<uint32_t >>()(tmp,ballotInput,workgroupSize,subgroupSize);
237
+ }
238
+ else
239
+ Arithmetic<OP<uint32_t >>()(tmp,inputData+workgroupOffset,workgroupSize,subgroupSize);
239
240
for (uint32_t localInvocationIndex=0u ; localInvocationIndex<workgroupSize; localInvocationIndex++)
240
241
if (tmp[localInvocationIndex]!=dataFromBuffer[workgroupOffset+localInvocationIndex])
241
242
{
242
- os::Printer::log (" Failed test #" + std::to_string (workgroupSize) + " (" + Arithmetic<OP<uint32_t >>::name + " ) (" + OP<uint32_t >::name + " ) Expected " + std::to_string (dataFromBuffer[workgroupOffset + localInvocationIndex])+ " got " + std::to_string (tmp[ localInvocationIndex]), ELL_ERROR);
243
+ os::Printer::log (" Failed test #" + std::to_string (workgroupSize) + " (" + Arithmetic<OP<uint32_t >>::name + " ) (" + OP<uint32_t >::name + " ) Expected " + std::to_string (tmp[ localInvocationIndex])+ " got " + std::to_string (dataFromBuffer[workgroupOffset + localInvocationIndex]), ELL_ERROR);
243
244
success = false ;
244
245
break ;
245
246
}
246
247
}
248
+ delete[] ballotInput;
247
249
delete[] tmp;
248
250
}
249
251
else
@@ -271,7 +273,7 @@ bool runTest(video::IVideoDriver* driver, video::IGPUComputePipeline* pipeline,
271
273
passed = validateResults<Arithmetic,::max>(driver, inputData, workgroupSize, workgroupCount, buffers[6 ].get ())&&passed;
272
274
if (is_workgroup_test)
273
275
{
274
- passed = validateResults<Arithmetic, countBits >(driver, inputData, workgroupSize, workgroupCount, buffers[7 ].get ()) && passed;
276
+ passed = validateResults<Arithmetic,ballot >(driver, inputData, workgroupSize, workgroupCount, buffers[7 ].get ()) && passed;
275
277
}
276
278
277
279
return passed;
0 commit comments