@@ -163,6 +163,36 @@ struct emulatedWorkgroupReduction
163
163
164
164
_IRR_STATIC_INLINE_CONSTEXPR const char * name = " workgroup reduction" ;
165
165
};
166
+ template <class OP >
167
+ struct emulatedWorkgroupScanExclusive
168
+ {
169
+ using type_t = typename OP::type_t ;
170
+
171
+ inline type_t operator ()(const type_t * workgroupData, const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize)
172
+ {
173
+ type_t retval = OP::IdentityElement;
174
+ for (auto i=0u ; i<localInvocationIndex; i++)
175
+ retval = OP ()(retval,workgroupData[i]);
176
+ return retval;
177
+ }
178
+
179
+ _IRR_STATIC_INLINE_CONSTEXPR const char * name = " workgroup exclusive scan" ;
180
+ };
181
+ template <class OP >
182
+ struct emulatedWorkgroupScanInclusive
183
+ {
184
+ using type_t = typename OP::type_t ;
185
+
186
+ inline type_t operator ()(const type_t * workgroupData, const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize)
187
+ {
188
+ type_t retval = OP::IdentityElement;
189
+ for (auto i=0u ; i<=localInvocationIndex; i++)
190
+ retval = OP ()(retval,workgroupData[i]);
191
+ return retval;
192
+ }
193
+
194
+ _IRR_STATIC_INLINE_CONSTEXPR const char * name = " workgroup inclusive scan" ;
195
+ };
166
196
167
197
168
198
#include " common.glsl"
@@ -334,8 +364,8 @@ int main()
334
364
getShaderGLSL (" ../testSubgroupExclusive.comp" ),
335
365
getShaderGLSL (" ../testSubgroupInclusive.comp" ),
336
366
getShaderGLSL (" ../testWorkgroupReduce.comp" ),
337
- // getShaderGLSL("../testWorkgroupExclusive.comp"),
338
- // getShaderGLSL("../testWorkgroupInclusive.comp")
367
+ getShaderGLSL (" ../testWorkgroupExclusive.comp" ),
368
+ getShaderGLSL (" ../testWorkgroupInclusive.comp" )
339
369
};
340
370
constexpr auto kTestTypeCount = sizeof (shaderGLSL)/sizeof (GLSLCodeWithWorkgroup);
341
371
@@ -354,7 +384,7 @@ int main()
354
384
// max workgroup size is hardcoded to 1024
355
385
uint32_t totalFailCount = 0 ;
356
386
const auto ds = descriptorSet.get ();
357
- for (uint32_t workgroupSize=1u ; workgroupSize<=1024u ; workgroupSize++)
387
+ for (uint32_t workgroupSize=8u ; workgroupSize<=1024u ; workgroupSize++)
358
388
{
359
389
core::smart_refctd_ptr<IGPUComputePipeline> pipelines[kTestTypeCount ];
360
390
for (uint32_t i=0u ; i<kTestTypeCount ; i++)
@@ -364,12 +394,12 @@ int main()
364
394
365
395
driver->beginScene (true );
366
396
const video::IGPUDescriptorSet* ds = descriptorSet.get ();
367
- // passed = runTest<emulatedSubgroupReduction>(driver,pipelines[0u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
368
- // passed = runTest<emulatedSubgroupScanExclusive>(driver,pipelines[1u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
369
- // passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[2u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
397
+ passed = runTest<emulatedSubgroupReduction>(driver,pipelines[0u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
398
+ passed = runTest<emulatedSubgroupScanExclusive>(driver,pipelines[1u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
399
+ passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[2u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
370
400
passed = runTest<emulatedWorkgroupReduction>(driver,pipelines[3u ].get (),descriptorSet.get (),inputData,workgroupSize,buffers)&&passed;
371
- // passed = runTest<emulatedSubgroupScanInclusive >(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
372
- // passed = runTest<emulatedSubgroupScanInclusive >(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
401
+ // passed = runTest<emulatedWorkgroupScanExclusive >(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
402
+ // passed = runTest<emulatedWorkgroupScanInclusive >(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
373
403
374
404
if (passed)
375
405
os::Printer::log (" Passed test #" + std::to_string (workgroupSize), ELL_INFORMATION);
@@ -381,7 +411,7 @@ int main()
381
411
driver->endScene ();
382
412
}
383
413
os::Printer::log (" ==========Result==========" , ELL_INFORMATION);
384
- os::Printer::log (" Failed: " + totalFailCount, ELL_INFORMATION);
414
+ os::Printer::log (" Fail Count: " + std::to_string ( totalFailCount) , ELL_INFORMATION);
385
415
386
416
delete [] inputData;
387
417
return 0 ;
0 commit comments