Skip to content

Commit 0f8e7e5

Browse files
Merge pull request #24 from Devsh-Graphics-Programming/ballot-test
Ballot test
2 parents 6ac3b9a + 2b5e333 commit 0f8e7e5

File tree

7 files changed

+70
-42
lines changed

7 files changed

+70
-42
lines changed

examples_tests/48.ArithmeticUnitTest/main.cpp

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ struct and
1818
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = ~0ull; // this should be a reinterpret cast
1919

2020
inline T operator()(T left, T right) { return left & right; }
21-
21+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
2222
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "and";
2323
};
2424
template<typename T>
@@ -28,7 +28,7 @@ struct xor
2828
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = 0ull; // this should be a reinterpret cast
2929

3030
inline T operator()(T left, T right) { return left ^ right; }
31-
31+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
3232
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "xor";
3333
};
3434
template<typename T>
@@ -38,7 +38,7 @@ struct or
3838
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = 0ull; // this should be a reinterpret cast
3939

4040
inline T operator()(T left, T right) { return left | right; }
41-
41+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
4242
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "or";
4343
};
4444
template<typename T>
@@ -48,7 +48,7 @@ struct add
4848
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(0);
4949

5050
inline T operator()(T left, T right) { return left + right; }
51-
51+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
5252
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "add";
5353
};
5454
template<typename T>
@@ -58,7 +58,7 @@ struct mul
5858
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(1);
5959

6060
inline T operator()(T left, T right) { return left * right; }
61-
61+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
6262
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "mul";
6363
};
6464
template<typename T>
@@ -68,7 +68,7 @@ struct min
6868
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = std::numeric_limits<T>::max();
6969

7070
inline T operator()(T left, T right) { return std::min<T>(left, right); }
71-
71+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
7272
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "min";
7373
};
7474
template<typename T>
@@ -78,9 +78,11 @@ struct max
7878
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = std::numeric_limits<T>::lowest();
7979

8080
inline T operator()(T left, T right) { return std::max<T>(left, right); }
81-
81+
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = false;
8282
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "max";
8383
};
84+
template<typename T>
85+
struct ballot : add<T> {};
8486

8587

8688
//subgroup method emulations on the CPU, to verify the results of the GPU methods
@@ -111,7 +113,6 @@ struct emulatedSubgroupReduction : emulatedSubgroupCommon<emulatedSubgroupReduct
111113
red = OP()(red,subgroupData[i]);
112114
std::fill(outSubgroupData,outSubgroupData+clampedSubgroupSize,red);
113115
}
114-
115116
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "subgroup reduction";
116117
};
117118
template<class OP>
@@ -125,7 +126,6 @@ struct emulatedSubgroupScanExclusive : emulatedSubgroupCommon<emulatedSubgroupSc
125126
for (auto i=1u; i<clampedSubgroupSize; i++)
126127
outSubgroupData[i] = OP()(outSubgroupData[i-1u],subgroupData[i-1u]);
127128
}
128-
129129
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "subgroup exclusive scan";
130130
};
131131
template<class OP>
@@ -139,7 +139,6 @@ struct emulatedSubgroupScanInclusive : emulatedSubgroupCommon<emulatedSubgroupSc
139139
for (auto i=1u; i<clampedSubgroupSize; i++)
140140
outSubgroupData[i] = OP()(outSubgroupData[i-1u],subgroupData[i]);
141141
}
142-
143142
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "subgroup inclusive scan";
144143
};
145144

@@ -151,12 +150,11 @@ struct emulatedWorkgroupReduction
151150

152151
// CPU emulation of a workgroup-wide reduction, used to verify the GPU result.
// Folds the `workgroupSize` values in `workgroupData` with OP, left to right,
// and writes the single reduced value into every one of the `workgroupSize`
// slots of `outputData`. `subgroupSize` is not used by this body.
inline void operator()(type_t* outputData, const type_t* workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
153152
{
154-
type_t red = workgroupData[0];
153+
// Seed the accumulator with the first element, optionally passed through OP first.
// NOTE(review): the seed operand is the literal 0, not OP::IdentityElement; 0 is not
// the identity for and (~0), mul (1) or min (max value), so enabling runOPonFirst for
// those would change the result. Every operator visible here sets runOPonFirst to
// false, so the branch looks dormant at present - confirm the intent before relying on it.
type_t red = OP::runOPonFirst ? OP()(0, workgroupData[0]) : workgroupData[0];
155154
// fold in the remaining elements
for (auto i=1u; i<workgroupSize; i++)
156155
red = OP()(red,workgroupData[i]);
157156
// a reduction yields the same value for every invocation in the workgroup
std::fill(outputData,outputData+workgroupSize,red);
158157
}
159-
160158
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "workgroup reduction";
161159
};
162160
template<class OP>
@@ -170,7 +168,6 @@ struct emulatedWorkgroupScanExclusive
170168
for (auto i=1u; i<workgroupSize; i++)
171169
outputData[i] = OP()(outputData[i-1u],workgroupData[i-1u]);
172170
}
173-
174171
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "workgroup exclusive scan";
175172
};
176173
template<class OP>
@@ -184,14 +181,14 @@ struct emulatedWorkgroupScanInclusive
184181
for (auto i=1u; i<workgroupSize; i++)
185182
outputData[i] = OP()(outputData[i-1u],workgroupData[i]);
186183
}
187-
188184
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "workgroup inclusive scan";
189185
};
190186

191187

192188
#include "common.glsl"
193189
constexpr uint32_t kBufferSize = BUFFER_DWORD_COUNT*sizeof(uint32_t);
194190

191+
195192
//returns true if result matches
196193
template<template<class> class Arithmetic, template<class> class OP>
197194
bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, const uint32_t workgroupSize, const uint32_t workgroupCount, video::IGPUBuffer* bufferToDownload)
@@ -228,18 +225,27 @@ bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, con
228225
// now check if the data obtained has valid values
229226
constexpr uint32_t subgroupSize = 4u;
230227
uint32_t* tmp = new uint32_t[workgroupSize];
228+
uint32_t* ballotInput = new uint32_t[workgroupSize];
231229
for (uint32_t workgroupID=0u; success&&workgroupID<workgroupCount; workgroupID++)
232230
{
233231
const auto workgroupOffset = workgroupID*workgroupSize;
234-
Arithmetic<OP<uint32_t>>()(tmp,inputData+workgroupOffset,workgroupSize,subgroupSize);
232+
if constexpr (std::is_same_v<OP<uint32_t>,ballot<uint32_t>>)
233+
{
234+
for (auto i=0u; i<workgroupSize; i++)
235+
ballotInput[i] = inputData[i+workgroupOffset]&0x1u;
236+
Arithmetic<OP<uint32_t>>()(tmp,ballotInput,workgroupSize,subgroupSize);
237+
}
238+
else
239+
Arithmetic<OP<uint32_t>>()(tmp,inputData+workgroupOffset,workgroupSize,subgroupSize);
235240
for (uint32_t localInvocationIndex=0u; localInvocationIndex<workgroupSize; localInvocationIndex++)
236241
if (tmp[localInvocationIndex]!=dataFromBuffer[workgroupOffset+localInvocationIndex])
237242
{
238-
os::Printer::log("Failed test #" + std::to_string(workgroupSize) + " (" + Arithmetic<OP<uint32_t>>::name + ") (" + OP<uint32_t>::name + ")", ELL_ERROR);
243+
os::Printer::log("Failed test #" + std::to_string(workgroupSize) + " (" + Arithmetic<OP<uint32_t>>::name + ") (" + OP<uint32_t>::name + ") Expected "+ std::to_string(tmp[localInvocationIndex])+ " got " + std::to_string(dataFromBuffer[workgroupOffset + localInvocationIndex]), ELL_ERROR);
239244
success = false;
240245
break;
241246
}
242247
}
248+
delete[] ballotInput;
243249
delete[] tmp;
244250
}
245251
else
@@ -250,7 +256,7 @@ bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, con
250256

251257
}
252258
template<template<class> class Arithmetic>
253-
bool runTest(video::IVideoDriver* driver, video::IGPUComputePipeline* pipeline, const video::IGPUDescriptorSet* ds, const uint32_t* inputData, const uint32_t workgroupSize, core::smart_refctd_ptr<IGPUBuffer>* const buffers)
259+
bool runTest(video::IVideoDriver* driver, video::IGPUComputePipeline* pipeline, const video::IGPUDescriptorSet* ds, const uint32_t* inputData, const uint32_t workgroupSize, core::smart_refctd_ptr<IGPUBuffer>* const buffers, bool is_workgroup_test = false)
254260
{
255261
driver->bindComputePipeline(pipeline);
256262
driver->bindDescriptorSets(video::EPBP_COMPUTE,pipeline->getLayout(),0u,1u,&ds,nullptr);
@@ -265,6 +271,11 @@ bool runTest(video::IVideoDriver* driver, video::IGPUComputePipeline* pipeline,
265271
passed = validateResults<Arithmetic,mul>(driver, inputData, workgroupSize, workgroupCount, buffers[4].get())&&passed;
266272
passed = validateResults<Arithmetic,::min>(driver, inputData, workgroupSize, workgroupCount, buffers[5].get())&&passed;
267273
passed = validateResults<Arithmetic,::max>(driver, inputData, workgroupSize, workgroupCount, buffers[6].get())&&passed;
274+
if(is_workgroup_test)
275+
{
276+
passed = validateResults<Arithmetic,ballot>(driver, inputData, workgroupSize, workgroupCount, buffers[7].get()) && passed;
277+
}
278+
268279
return passed;
269280
}
270281

@@ -300,43 +311,41 @@ int main()
300311
}
301312
auto gpuinputDataBuffer = driver->createFilledDeviceLocalGPUBufferOnDedMem(kBufferSize, inputData);
302313

303-
//create 7 buffers.
304-
core::smart_refctd_ptr<IGPUBuffer> buffers[7];
305-
for (size_t i = 0; i < 7; i++)
314+
//create 8 buffers.
315+
constexpr const int outputBufferCount = 8;
316+
constexpr const int totalBufferCount = outputBufferCount+1;
317+
318+
core::smart_refctd_ptr<IGPUBuffer> buffers[outputBufferCount];
319+
for (size_t i = 0; i < outputBufferCount; i++)
306320
{
307321
buffers[i] = driver->createDeviceLocalGPUBufferOnDedMem(kBufferSize);
308322
}
309323

310-
IGPUDescriptorSetLayout::SBinding binding[8] = {
311-
{0u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr}, //input with randomized numbers
312-
{1u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
313-
{2u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
314-
{3u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
315-
{4u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
316-
{5u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
317-
{6u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
318-
{7u,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr},
319-
};
320-
auto gpuDSLayout = driver->createGPUDescriptorSetLayout(binding, binding + 8);
321-
constexpr uint32_t pushconstantSize = 64u;
324+
IGPUDescriptorSetLayout::SBinding binding[totalBufferCount];
325+
for (uint32_t i = 0u; i < totalBufferCount; i++)
326+
{
327+
binding[i] = { i,EDT_STORAGE_BUFFER,1u,IGPUSpecializedShader::ESS_COMPUTE,nullptr };
328+
}
329+
auto gpuDSLayout = driver->createGPUDescriptorSetLayout(binding, binding + totalBufferCount);
330+
constexpr uint32_t pushconstantSize = 8u* totalBufferCount;
322331
SPushConstantRange pcRange[1] = { IGPUSpecializedShader::ESS_COMPUTE,0u,pushconstantSize };
323332
// The first two arguments are a [begin,end) pointer pair over the push constant ranges.
// pcRange holds exactly one range, so the end is pcRange + 1; using the byte size
// (pushconstantSize) as the element count would hand the driver a range that runs
// far past the end of the array.
auto pipelineLayout = driver->createGPUPipelineLayout(pcRange, pcRange + 1, core::smart_refctd_ptr(gpuDSLayout));
324333

325334
auto descriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr(gpuDSLayout));
326335
{
327-
IGPUDescriptorSet::SDescriptorInfo infos[8];
336+
IGPUDescriptorSet::SDescriptorInfo infos[totalBufferCount];
328337
infos[0].desc = gpuinputDataBuffer;
329338
infos[0].buffer = { 0u,kBufferSize };
330-
for (uint32_t i=1u; i<=7u; i++)
339+
for (uint32_t i=1u; i<= outputBufferCount; i++)
331340
{
332341
infos[i].desc = buffers[i - 1];
333342
infos[i].buffer = { 0u,kBufferSize };
334343

335344
}
336-
IGPUDescriptorSet::SWriteDescriptorSet writes[8];
337-
for (uint32_t i=0u; i<8u; i++)
345+
IGPUDescriptorSet::SWriteDescriptorSet writes[totalBufferCount];
346+
for (uint32_t i=0u; i< totalBufferCount; i++)
338347
writes[i] = { descriptorSet.get(),i,0u,1u,EDT_STORAGE_BUFFER,infos + i };
339-
driver->updateDescriptorSets(8, writes, 0u, nullptr);
348+
driver->updateDescriptorSets(totalBufferCount, writes, 0u, nullptr);
340349
}
341350
struct GLSLCodeWithWorkgroup {
342351
uint32_t workgroup_definition_position;
@@ -391,9 +400,9 @@ int main()
391400
passed = runTest<emulatedSubgroupReduction>(driver,pipelines[0u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
392401
passed = runTest<emulatedSubgroupScanExclusive>(driver,pipelines[1u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
393402
passed = runTest<emulatedSubgroupScanInclusive>(driver,pipelines[2u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
394-
passed = runTest<emulatedWorkgroupReduction>(driver,pipelines[3u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
395-
passed = runTest<emulatedWorkgroupScanExclusive>(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
396-
passed = runTest<emulatedWorkgroupScanInclusive>(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers)&&passed;
403+
passed = runTest<emulatedWorkgroupReduction>(driver,pipelines[3u].get(),descriptorSet.get(),inputData,workgroupSize,buffers,true)&&passed;
404+
passed = runTest<emulatedWorkgroupScanExclusive>(driver,pipelines[4u].get(),descriptorSet.get(),inputData,workgroupSize,buffers, true)&&passed;
405+
passed = runTest<emulatedWorkgroupScanInclusive>(driver,pipelines[5u].get(),descriptorSet.get(),inputData,workgroupSize,buffers, true)&&passed;
397406

398407
if (passed)
399408
os::Printer::log("Passed test #" + std::to_string(workgroupSize), ELL_INFORMATION);

examples_tests/48.ArithmeticUnitTest/shaderCommon.glsl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,8 @@ layout(set = 0, binding = 6, std430) writeonly buffer outmin
3333
layout(set = 0, binding = 7, std430) writeonly buffer outmax
3434
{
3535
uint maxOutput[];
36+
};
37+
layout(set = 0, binding = 8, std430) writeonly buffer outbitcount
38+
{
39+
uint bitCountOutput[];
3640
};

examples_tests/48.ArithmeticUnitTest/testWorkgroupExclusive.comp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ void main()
1313
multOutput[gl_GlobalInvocationID.x] = nbl_glsl_workgroupExclusiveMul(sourceVal);
1414
minOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupExclusiveMin(sourceVal);
1515
maxOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupExclusiveMax(sourceVal);
16+
nbl_glsl_workgroupBallot((sourceVal&0x1u)==0x1u);
17+
bitCountOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupBallotExclusiveBitCount();
18+
1619
}

examples_tests/48.ArithmeticUnitTest/testWorkgroupInclusive.comp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ void main()
1313
multOutput[gl_GlobalInvocationID.x] = nbl_glsl_workgroupInclusiveMul(sourceVal);
1414
minOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupInclusiveMin(sourceVal);
1515
maxOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupInclusiveMax(sourceVal);
16+
nbl_glsl_workgroupBallot((sourceVal&0x1u)==0x1u);
17+
bitCountOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupBallotInclusiveBitCount();
1618
}

examples_tests/48.ArithmeticUnitTest/testWorkgroupReduce.comp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ void main()
1313
multOutput[gl_GlobalInvocationID.x] = nbl_glsl_workgroupMul(sourceVal);
1414
minOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupMin(sourceVal);
1515
maxOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupMax(sourceVal);
16+
nbl_glsl_workgroupBallot((sourceVal&0x1u)==0x1u);
17+
bitCountOutput [gl_GlobalInvocationID.x] = nbl_glsl_workgroupBallotBitCount();
18+
1619
}
Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
11
#include "shaderCommon.glsl"
22

3-
#include "nbl/builtin/glsl/workgroup/arithmetic.glsl"
3+
// ORDER OF INCLUDES MATTERS !!!!!
4+
// first the feature that requires the most shared memory should be included
5+
// anyway when one is using more than 2 features that rely on shared memory,
6+
// they should declare the shared memory of appropriate size by themselves.
7+
// But in this unit test we don't because we need to test if the default
8+
// sizing macros actually work for all workgroup sizes.
9+
#include <nbl/builtin/glsl/workgroup/arithmetic.glsl>
10+
#include <nbl/builtin/glsl/workgroup/ballot.glsl>

include/nbl/builtin/glsl/workgroup/ballot.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ uint nbl_glsl_workgroupBallotScanBitCount_impl(in bool exclusive)
273273
barrier();
274274
}
275275

276-
const uint mask = 0xffffffffu>>((exclusive ? 32u:31u)-(gl_LocalInvocationIndex&31u));
276+
const uint mask = (exclusive ? 0x7fffffffu:0xffffffffu)>>(31u-(gl_LocalInvocationIndex&31u));
277277
return globalCount+bitCount(localBitfield&mask);
278278
}
279279

0 commit comments

Comments
 (0)