Skip to content

Commit 4648b14

Browse files
committed
Rework the toroidal access on the scatter shader and add comments
1 parent ad26699 commit 4648b14

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

include/nbl/builtin/hlsl/sort/counting.hlsl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ struct counting
5252

5353
sdata.workgroupExecutionAndMemoryBarrier();
5454

55+
// Parallel reads must be coalesced
5556
uint32_t index = workGroupIndex * GroupSize * params.elementsPerWT + tid;
56-
uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);
57+
uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT); // implicitly breaks when params.dataElementCount is reached
5758

5859
for (; index < endIndex; index += GroupSize)
5960
{
@@ -62,8 +63,6 @@ struct counting
6263
continue;
6364
sdata.atomicAdd(k - params.minimum, (uint32_t) 1);
6465
}
65-
66-
sdata.workgroupExecutionAndMemoryBarrier();
6766
}
6867

6968
void histogram(
@@ -75,7 +74,11 @@ struct counting
7574
{
7675
build_histogram(key, sdata, params);
7776

77+
// wait for the histogramming to finish
78+
sdata.workgroupExecutionAndMemoryBarrier();
79+
7880
uint32_t tid = workgroup::SubgroupContiguousIndex();
81+
// because first chunk of histogram and workgroup scan scratch are aliased
7982
uint32_t histogram_value = sdata.get(tid);
8083

8184
sdata.workgroupExecutionAndMemoryBarrier();
@@ -91,11 +94,16 @@ struct counting
9194
uint32_t keyBucketStart = GroupSize * i;
9295
uint32_t vid = tid + keyBucketStart;
9396

97+
// no special-case if statement for the last iteration is needed
9498
if (is_last_wg_invocation)
9599
{
96100
sdata.set(keyBucketStart, sdata.get(keyBucketStart) + sum);
97101
}
98102

103+
// propagate last block tail to next block head and protect against subsequent scans stepping on each other's toes
104+
sdata.workgroupExecutionAndMemoryBarrier();
105+
106+
// no aliasing anymore
99107
const uint32_t val = vid < KeyBucketCount ? sdata.get(vid) : 0;
100108
sum = inclusive_scan(val, sdata);
101109
histogram.atomicAdd(vid, sum);
@@ -112,15 +120,20 @@ struct counting
112120
{
113121
build_histogram(key, sdata, params);
114122

123+
// wait for the histogramming to finish
124+
sdata.workgroupExecutionAndMemoryBarrier();
125+
115126
uint32_t tid = workgroup::SubgroupContiguousIndex();
116-
uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * workGroupIndex) % GroupSize;
127+
const uint32_t shift = glsl::gl_SubgroupSize() * workGroupIndex;
117128

118-
for (; shifted_tid < KeyBucketCount; shifted_tid += GroupSize)
129+
for (uint32_t vtid=tid; vtid<KeyBucketCount; vtid+=GroupSize)
119130
{
120-
uint32_t bucket_value = sdata.get(shifted_tid);
121-
uint32_t exclusive_value = histogram.atomicSub(shifted_tid, bucket_value) - bucket_value;
131+
// have to use the modulo operator in case `KeyBucketCount<=2*GroupSize`; better hope `KeyBucketCount` is a power of two
132+
const uint32_t shifted_tid = (vtid + shift) % KeyBucketCount;
133+
const uint32_t bucket_value = sdata.get(shifted_tid);
134+
const uint32_t firstOutputIndex = histogram.atomicSub(shifted_tid, bucket_value) - bucket_value;
122135

123-
sdata.set(shifted_tid, exclusive_value);
136+
sdata.set(shifted_tid, firstOutputIndex);
124137
}
125138

126139
sdata.workgroupExecutionAndMemoryBarrier();

0 commit comments

Comments
 (0)