Skip to content

Commit 2cce4f9

Browse files
add all test cases for workgroup scans and prepare the template definitions for them as well
1 parent 4a7dd3a commit 2cce4f9

File tree

3 files changed

+287
-83
lines changed

3 files changed

+287
-83
lines changed

examples_tests/48.ArithmeticUnitTest/main.cpp

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,36 @@ struct emulatedWorkgroupReduction
163163

164164
_IRR_STATIC_INLINE_CONSTEXPR const char* name = "workgroup reduction";
165165
};
166+
template<class OP>
167+
struct emulatedWorkgroupScanExclusive
168+
{
169+
using type_t = typename OP::type_t;
170+
171+
inline type_t operator()(const type_t* workgroupData, const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize)
172+
{
173+
type_t retval = OP::IdentityElement;
174+
for (auto i=0u; i<localInvocationIndex; i++)
175+
retval = OP()(retval,workgroupData[i]);
176+
return retval;
177+
}
178+
179+
_IRR_STATIC_INLINE_CONSTEXPR const char* name = "workgroup exclusive scan";
180+
};
181+
template<class OP>
182+
struct emulatedWorkgroupScanInclusive
183+
{
184+
using type_t = typename OP::type_t;
185+
186+
inline type_t operator()(const type_t* workgroupData, const uint32_t localInvocationIndex, uint32_t subgroupSize, uint32_t workgroupSize)
187+
{
188+
type_t retval = OP::IdentityElement;
189+
for (auto i=0u; i<=localInvocationIndex; i++)
190+
retval = OP()(retval,workgroupData[i]);
191+
return retval;
192+
}
193+
194+
_IRR_STATIC_INLINE_CONSTEXPR const char* name = "workgroup inclusive scan";
195+
};
166196

167197

168198
#include "common.glsl"
@@ -334,8 +364,8 @@ int main()
334364
getShaderGLSL("../testSubgroupExclusive.comp"),
335365
getShaderGLSL("../testSubgroupInclusive.comp"),
336366
getShaderGLSL("../testWorkgroupReduce.comp"),
337-
//getShaderGLSL("../testWorkgroupExclusive.comp"),
338-
//getShaderGLSL("../testWorkgroupInclusive.comp")
367+
getShaderGLSL("../testWorkgroupExclusive.comp"),
368+
getShaderGLSL("../testWorkgroupInclusive.comp")
339369
};
340370
constexpr auto kTestTypeCount = sizeof(shaderGLSL)/sizeof(GLSLCodeWithWorkgroup);
341371

@@ -354,7 +384,7 @@ int main()
354384
//max workgroup size is hardcoded to 1024
355385
uint32_t totalFailCount = 0;
356386
const auto ds = descriptorSet.get();
357-
for (uint32_t workgroupSize=1u; workgroupSize<=1024u; workgroupSize++)
387+
for (uint32_t workgroupSize=8u; workgroupSize<=1024u; workgroupSize++)
358388
{
359389
core::smart_refctd_ptr<IGPUComputePipeline> pipelines[kTestTypeCount];
360390
for (uint32_t i=0u; i<kTestTypeCount; i++)
@@ -364,12 +394,12 @@ int main()
364394

365395
driver->beginScene(true);
366396
const video::IGPUDescriptorSet* ds = descriptorSet.get();
367-
//passed = runTest<emulatedSubgroupReduction>(driver,pipelines[0u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
368-
//passed = runTest<emulatedSubgroupScanExclusive>(driver,pipelines[1u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
369-
//passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[2u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
397+
passed = runTest<emulatedSubgroupReduction>(driver,pipelines[0u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
398+
passed = runTest<emulatedSubgroupScanExclusive>(driver,pipelines[1u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
399+
passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[2u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
370400
passed = runTest<emulatedWorkgroupReduction>(driver,pipelines[3u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
371-
//passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
372-
//passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
401+
//passed = runTest<emulatedWorkgroupScanExclusive>(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
402+
//passed = runTest<emulatedWorkgroupScanInclusive>(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
373403

374404
if (passed)
375405
os::Printer::log("Passed test #" + std::to_string(workgroupSize), ELL_INFORMATION);
@@ -381,7 +411,7 @@ int main()
381411
driver->endScene();
382412
}
383413
os::Printer::log("==========Result==========", ELL_INFORMATION);
384-
os::Printer::log("Failed:" + totalFailCount, ELL_INFORMATION);
414+
os::Printer::log("Fail Count: " + std::to_string(totalFailCount), ELL_INFORMATION);
385415

386416
delete [] inputData;
387417
return 0;

0 commit comments

Comments
 (0)