Skip to content

Commit 2b5e333

Browse files
committed
fixed exclusive ballot bit count
1 parent 938f790 commit 2b5e333

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

examples_tests/48.ArithmeticUnitTest/main.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,7 @@ struct max
8282
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "max";
8383
};
8484
template<typename T>
85-
struct countBits
86-
{
87-
using type_t = T;
88-
_NBL_STATIC_INLINE_CONSTEXPR T IdentityElement = T(0);
89-
90-
inline T operator()(T left, T right) { return left + (right&1u); }
91-
_NBL_STATIC_INLINE_CONSTEXPR bool runOPonFirst = true;
92-
_NBL_STATIC_INLINE_CONSTEXPR const char* name = "bitcount";
93-
};
85+
struct ballot : add<T> {};
9486

9587

9688
//subgroup method emulations on the CPU, to verify the results of the GPU methods
@@ -172,7 +164,7 @@ struct emulatedWorkgroupScanExclusive
172164

173165
inline void operator()(type_t* outputData, const type_t* workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
174166
{
175-
outputData[0u] = OP::runOPonFirst ? OP()(0, workgroupData[0]) : OP::IdentityElement;
167+
outputData[0u] = OP::IdentityElement;
176168
for (auto i=1u; i<workgroupSize; i++)
177169
outputData[i] = OP()(outputData[i-1u],workgroupData[i-1u]);
178170
}
@@ -185,7 +177,7 @@ struct emulatedWorkgroupScanInclusive
185177

186178
inline void operator()(type_t* outputData, const type_t* workgroupData, uint32_t workgroupSize, uint32_t subgroupSize)
187179
{
188-
outputData[0u] = OP::runOPonFirst ? OP()(0, workgroupData[0]) : workgroupData[0u];
180+
outputData[0u] = workgroupData[0u];
189181
for (auto i=1u; i<workgroupSize; i++)
190182
outputData[i] = OP()(outputData[i-1u],workgroupData[i]);
191183
}
@@ -196,6 +188,7 @@ struct emulatedWorkgroupScanInclusive
196188
#include "common.glsl"
197189
constexpr uint32_t kBufferSize = BUFFER_DWORD_COUNT*sizeof(uint32_t);
198190

191+
199192
//returns true if result matches
200193
template<template<class> class Arithmetic, template<class> class OP>
201194
bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, const uint32_t workgroupSize, const uint32_t workgroupCount, video::IGPUBuffer* bufferToDownload)
@@ -232,18 +225,27 @@ bool validateResults(video::IVideoDriver* driver, const uint32_t* inputData, con
232225
// now check if the data obtained has valid values
233226
constexpr uint32_t subgroupSize = 4u;
234227
uint32_t* tmp = new uint32_t[workgroupSize];
228+
uint32_t* ballotInput = new uint32_t[workgroupSize];
235229
for (uint32_t workgroupID=0u; success&&workgroupID<workgroupCount; workgroupID++)
236230
{
237231
const auto workgroupOffset = workgroupID*workgroupSize;
238-
Arithmetic<OP<uint32_t>>()(tmp,inputData+workgroupOffset,workgroupSize,subgroupSize);
232+
if constexpr (std::is_same_v<OP<uint32_t>,ballot<uint32_t>>)
233+
{
234+
for (auto i=0u; i<workgroupSize; i++)
235+
ballotInput[i] = inputData[i+workgroupOffset]&0x1u;
236+
Arithmetic<OP<uint32_t>>()(tmp,ballotInput,workgroupSize,subgroupSize);
237+
}
238+
else
239+
Arithmetic<OP<uint32_t>>()(tmp,inputData+workgroupOffset,workgroupSize,subgroupSize);
239240
for (uint32_t localInvocationIndex=0u; localInvocationIndex<workgroupSize; localInvocationIndex++)
240241
if (tmp[localInvocationIndex]!=dataFromBuffer[workgroupOffset+localInvocationIndex])
241242
{
242-
os::Printer::log("Failed test #" + std::to_string(workgroupSize) + " (" + Arithmetic<OP<uint32_t>>::name + ") (" + OP<uint32_t>::name + ") Expected "+ std::to_string(dataFromBuffer[workgroupOffset + localInvocationIndex])+ " got " + std::to_string(tmp[localInvocationIndex]), ELL_ERROR);
243+
os::Printer::log("Failed test #" + std::to_string(workgroupSize) + " (" + Arithmetic<OP<uint32_t>>::name + ") (" + OP<uint32_t>::name + ") Expected "+ std::to_string(tmp[localInvocationIndex])+ " got " + std::to_string(dataFromBuffer[workgroupOffset + localInvocationIndex]), ELL_ERROR);
243244
success = false;
244245
break;
245246
}
246247
}
248+
delete[] ballotInput;
247249
delete[] tmp;
248250
}
249251
else
@@ -271,7 +273,7 @@ bool runTest(video::IVideoDriver* driver, video::IGPUComputePipeline* pipeline,
271273
passed = validateResults<Arithmetic,::max>(driver, inputData, workgroupSize, workgroupCount, buffers[6].get())&&passed;
272274
if(is_workgroup_test)
273275
{
274-
passed = validateResults<Arithmetic, countBits>(driver, inputData, workgroupSize, workgroupCount, buffers[7].get()) && passed;
276+
passed = validateResults<Arithmetic,ballot>(driver, inputData, workgroupSize, workgroupCount, buffers[7].get()) && passed;
275277
}
276278

277279
return passed;

include/nbl/builtin/glsl/workgroup/ballot.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ uint nbl_glsl_workgroupBallotScanBitCount_impl(in bool exclusive)
273273
barrier();
274274
}
275275

276-
const uint mask = 0xffffffffu>>((exclusive ? 32u:31u)-(gl_LocalInvocationIndex&31u));
276+
const uint mask = (exclusive ? 0x7fffffffu:0xffffffffu)>>(31u-(gl_LocalInvocationIndex&31u));
277277
return globalCount+bitCount(localBitfield&mask);
278278
}
279279

0 commit comments

Comments
 (0)