Skip to content

Commit f4d7249

Browse files
committed
Make changes to counting.hlsl
-> Use implicit key_t instead of explicit Key template parameter -> Pass workGroupIndex specifically at Counter creation instead of with uniform data -> Remove global memory toroidal access from the histogram shader
1 parent f46073d commit f4d7249

File tree

2 files changed

+37
-28
lines changed

2 files changed

+37
-28
lines changed

include/nbl/builtin/hlsl/sort/common.hlsl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ struct CountingParameters
1919

2020
uint32_t dataElementCount;
2121
uint32_t elementsPerWT;
22-
uint32_t workGroupIndex;
2322
Key minimum;
2423
Key maximum;
2524
};

include/nbl/builtin/hlsl/sort/counting.hlsl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ namespace sort
1818
template<
1919
uint16_t GroupSize,
2020
uint16_t KeyBucketCount,
21-
typename Key,
2221
typename KeyAccessor,
2322
typename ValueAccessor,
2423
typename HistogramAccessor,
@@ -27,35 +26,33 @@ template<
2726
>
2827
struct counting
2928
{
30-
uint32_t inclusive_scan(uint32_t value, NBL_REF_ARG(SharedAccessor) sdata)
29+
using key_t = decltype(impl::declval < KeyAccessor > ().get(0));
30+
using this_t = counting<GroupSize, KeyBucketCount, KeyAccessor, ValueAccessor, HistogramAccessor, SharedAccessor>;
31+
32+
static this_t create(const uint32_t workGroupIndex)
3133
{
32-
return workgroup::inclusive_scan < plus < uint32_t >, GroupSize >::
33-
template __call <SharedAccessor>(value, sdata);
34+
this_t retval;
35+
retval.workGroupIndex = workGroupIndex;
36+
return retval;
3437
}
3538

36-
uint32_t toroidal_histogram_add(uint32_t tid, uint32_t sum, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
39+
uint32_t inclusive_scan(uint32_t value, NBL_REF_ARG(SharedAccessor) sdata)
3740
{
38-
sdata.workgroupExecutionAndMemoryBarrier();
39-
40-
sdata.set(tid % GroupSize, sum);
41-
uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
42-
43-
sdata.workgroupExecutionAndMemoryBarrier();
44-
45-
return histogram.atomicAdd((tid / GroupSize) * GroupSize + shifted_tid, sdata.get(shifted_tid));
41+
return workgroup::inclusive_scan < plus < uint32_t >, GroupSize >::
42+
template __call <SharedAccessor>(value, sdata);
4643
}
4744

48-
void build_histogram(NBL_REF_ARG( KeyAccessor) key, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
45+
void build_histogram(NBL_REF_ARG( KeyAccessor) key, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<key_t> params)
4946
{
5047
uint32_t tid = workgroup::SubgroupContiguousIndex();
5148

52-
for (; tid < KeyBucketCount; tid += GroupSize) {
53-
sdata.set(tid, 0);
49+
for (uint32_t vid = tid; vid < KeyBucketCount; vid += GroupSize) {
50+
sdata.set(vid, 0);
5451
}
5552

5653
sdata.workgroupExecutionAndMemoryBarrier();
5754

58-
uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid % GroupSize;
55+
uint32_t index = workGroupIndex * GroupSize * params.elementsPerWT + tid;
5956
uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);
6057

6158
for (; index < endIndex; index += GroupSize)
@@ -69,7 +66,12 @@ struct counting
6966
sdata.workgroupExecutionAndMemoryBarrier();
7067
}
7168

72-
void histogram(NBL_REF_ARG( KeyAccessor) key, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
69+
void histogram(
70+
NBL_REF_ARG( KeyAccessor) key,
71+
NBL_REF_ARG(HistogramAccessor) histogram,
72+
NBL_REF_ARG(SharedAccessor) sdata,
73+
const CountingParameters<key_t> params
74+
)
7375
{
7476
build_histogram(key, sdata, params);
7577

@@ -79,30 +81,36 @@ struct counting
7981
sdata.workgroupExecutionAndMemoryBarrier();
8082

8183
uint32_t sum = inclusive_scan(histogram_value, sdata);
82-
toroidal_histogram_add(tid, sum, histogram, sdata, params);
84+
histogram.atomicAdd(tid, sum);
8385

8486
const bool is_last_wg_invocation = tid == (GroupSize - 1);
8587
const uint16_t adjusted_key_bucket_count = ((KeyBucketCount - 1) / GroupSize + 1) * GroupSize;
8688

87-
for (tid += GroupSize; tid < adjusted_key_bucket_count; tid += GroupSize)
89+
for (uint32_t vid = tid + GroupSize; vid < adjusted_key_bucket_count; vid += GroupSize)
8890
{
8991
if (is_last_wg_invocation)
9092
{
91-
uint32_t startIndex = tid - tid % GroupSize;
93+
uint32_t startIndex = vid - tid;
9294
sdata.set(startIndex, sdata.get(startIndex) + sum);
9395
}
9496

95-
sum = inclusive_scan(sdata.get(tid), sdata);
96-
toroidal_histogram_add(tid, sum, histogram, sdata, params);
97+
sum = inclusive_scan(sdata.get(vid), sdata);
98+
histogram.atomicAdd(vid, sum);
9799
}
98100
}
99101

100-
void scatter(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(ValueAccessor) val, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
102+
void scatter(
103+
NBL_REF_ARG( KeyAccessor) key,
104+
NBL_REF_ARG(ValueAccessor) val,
105+
NBL_REF_ARG(HistogramAccessor) histogram,
106+
NBL_REF_ARG(SharedAccessor) sdata,
107+
const CountingParameters<key_t> params
108+
)
101109
{
102110
build_histogram(key, sdata, params);
103111

104112
uint32_t tid = workgroup::SubgroupContiguousIndex();
105-
uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
113+
uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * workGroupIndex) % GroupSize;
106114

107115
for (; shifted_tid < KeyBucketCount; shifted_tid += GroupSize)
108116
{
@@ -114,13 +122,13 @@ struct counting
114122

115123
sdata.workgroupExecutionAndMemoryBarrier();
116124

117-
uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid;
125+
uint32_t index = workGroupIndex * GroupSize * params.elementsPerWT + tid;
118126
uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);
119127

120128
[unroll]
121129
for (; index < endIndex; index += GroupSize)
122130
{
123-
const Key k = key.get(index);
131+
const key_t k = key.get(index);
124132
if (robust && (k<params.minimum || k>params.maximum) )
125133
continue;
126134
const uint32_t v = val.get(index);
@@ -129,6 +137,8 @@ struct counting
129137
val.set(sortedIx, v);
130138
}
131139
}
140+
141+
uint32_t workGroupIndex;
132142
};
133143

134144
}

0 commit comments

Comments
 (0)