Skip to content

Commit f46073d

Browse files
committed
Add toroidal_histogram_add
1 parent e1e028b commit f46073d

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

include/nbl/builtin/hlsl/sort/counting.hlsl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ struct counting
3333
template __call <SharedAccessor>(value, sdata);
3434
}
3535

36+
// Adds one scanned histogram value to `histogram`, but at a slot rotated by the
// workgroup index rather than at the invocation's own slot.
//   tid       - bucket index handled by the caller; only `tid % GroupSize` and
//               `tid / GroupSize` are used here.
//   sum       - value this invocation publishes to `sdata` at slot `tid % GroupSize`.
//   histogram - accessor that receives the atomicAdd.
//   sdata     - shared accessor used to exchange values between invocations.
//   params    - only `params.workGroupIndex` is read here.
// Returns whatever `histogram.atomicAdd` returns.
// NOTE(review): presumably the rotation makes different workgroups touch
// different buckets at the same moment to reduce atomic contention - confirm.
uint32_t toroidal_histogram_add(uint32_t tid, uint32_t sum, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
37+
{
38+
// Barrier issued before sdata is overwritten below.
// NOTE(review): presumably lets earlier readers of sdata finish first - confirm.
sdata.workgroupExecutionAndMemoryBarrier();
39+
40+
// Publish this invocation's value at its slot inside the GroupSize-wide window.
sdata.set(tid % GroupSize, sum);
41+
// Slot offset by gl_SubgroupSize() * workGroupIndex, wrapping around at GroupSize.
uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
42+
43+
// Barrier between the sdata.set above and the sdata.get of another slot below.
sdata.workgroupExecutionAndMemoryBarrier();
44+
45+
// (tid / GroupSize) * GroupSize is the start of tid's GroupSize-aligned chunk;
// the value stored at the rotated slot is added to the matching bucket of that chunk.
return histogram.atomicAdd((tid / GroupSize) * GroupSize + shifted_tid, sdata.get(shifted_tid));
46+
}
47+
3648
void build_histogram(NBL_REF_ARG( KeyAccessor) key, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
3749
{
3850
uint32_t tid = workgroup::SubgroupContiguousIndex();
@@ -59,51 +71,50 @@ struct counting
5971

6072
// Calls build_histogram, then scans the values held in `sdata` one
// GroupSize-wide chunk at a time and adds each scanned value into `histogram`.
// NOTE(review): this listing is a diff view - a line following an `NN-` marker
// appears to be the removed version and a line following an `NN+` marker the
// added one, so some statements show up twice.
void histogram(NBL_REF_ARG( KeyAccessor) key, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
6173
{
62-
uint32_t tid = workgroup::SubgroupContiguousIndex();
63-
uint32_t buckets_per_thread = (KeyBucketCount - 1) / GroupSize + 1;
64-
6574
build_histogram(key, sdata, params);
6675

76+
uint32_t tid = workgroup::SubgroupContiguousIndex();
6777
// Read this invocation's value for the first chunk before the barrier and scan below.
uint32_t histogram_value = sdata.get(tid);
6878

6979
sdata.workgroupExecutionAndMemoryBarrier();
7080

7181
uint32_t sum = inclusive_scan(histogram_value, sdata);
72-
histogram.atomicAdd(tid, sum);
82+
toroidal_histogram_add(tid, sum, histogram, sdata, params);
7383

7484
// Evaluated before the loop advances tid, so it identifies invocation GroupSize - 1.
const bool is_last_wg_invocation = tid == (GroupSize - 1);
75-
// This form yields KeyBucketCount + GroupSize when KeyBucketCount is already a multiple of GroupSize.
const uint16_t adjusted_key_bucket_count = KeyBucketCount + (GroupSize - KeyBucketCount % GroupSize);
85+
// KeyBucketCount rounded up to a multiple of GroupSize; an exact multiple is left unchanged.
const uint16_t adjusted_key_bucket_count = ((KeyBucketCount - 1) / GroupSize + 1) * GroupSize;
7686

7787
// Remaining buckets, one GroupSize-wide chunk per iteration.
for (tid += GroupSize; tid < adjusted_key_bucket_count; tid += GroupSize)
7888
{
79-
if (is_last_wg_invocation) {
89+
if (is_last_wg_invocation)
90+
{
8091
// startIndex is the first slot of the current chunk. Adding `sum` there carries the
// previous chunk's result forward (assuming inclusive_scan returns the inclusive
// prefix sum, so the last invocation holds the chunk total - TODO confirm).
// NOTE(review): no barrier is visible between this write and the sdata.get(tid)
// below performed by the invocation that owns startIndex - confirm this is ordered.
uint32_t startIndex = tid - tid % GroupSize;
8192
sdata.set(startIndex, sdata.get(startIndex) + sum);
8293
}
8394

8495
sum = inclusive_scan(sdata.get(tid), sdata);
85-
86-
histogram.atomicAdd(tid, sum);
96+
toroidal_histogram_add(tid, sum, histogram, sdata, params);
8797
}
8898
}
8999

90100
void scatter(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(ValueAccessor) val, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
91101
{
92-
uint32_t tid = workgroup::SubgroupContiguousIndex();
93-
94102
build_histogram(key, sdata, params);
95103

96-
for (; tid < KeyBucketCount; tid += GroupSize)
104+
uint32_t tid = workgroup::SubgroupContiguousIndex();
105+
uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
106+
107+
for (; shifted_tid < KeyBucketCount; shifted_tid += GroupSize)
97108
{
98-
uint32_t bucket_value = sdata.get(tid);
99-
uint32_t exclusive_value = histogram.atomicSub(tid, bucket_value) - bucket_value;
109+
uint32_t bucket_value = sdata.get(shifted_tid);
110+
uint32_t exclusive_value = histogram.atomicSub(shifted_tid, bucket_value) - bucket_value;
100111

101-
sdata.set(tid, exclusive_value);
112+
sdata.set(shifted_tid, exclusive_value);
102113
}
103114

104115
sdata.workgroupExecutionAndMemoryBarrier();
105116

106-
uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid % GroupSize;
117+
uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid;
107118
uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);
108119

109120
[unroll]

0 commit comments

Comments
 (0)