@@ -33,6 +33,18 @@ struct counting
            template __call<SharedAccessor>(value, sdata);
    }

+    uint32_t toroidal_histogram_add(uint32_t tid, uint32_t sum, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
+    {
+        sdata.workgroupExecutionAndMemoryBarrier();
+
+        sdata.set(tid % GroupSize, sum);
+        uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
+
+        sdata.workgroupExecutionAndMemoryBarrier();
+
+        return histogram.atomicAdd((tid / GroupSize) * GroupSize + shifted_tid, sdata.get(shifted_tid));
+    }
+
    void build_histogram(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
    {
        uint32_t tid = workgroup::SubgroupContiguousIndex();
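Note: the new toroidal_histogram_add helper stages each invocation's scan result in shared memory and then has every invocation flush a rotated neighbour's value into the matching global bucket, so concurrently running workgroups start their atomicAdds at different offsets instead of all hitting bucket 0 first. Within a GroupSize-wide block the rotation is a bijection, so every bucket still receives exactly the value meant for it. A minimal host-side C++ sketch of that index rotation follows (GroupSize, the subgroup size and the workgroup indices are made-up illustrative values, not the library's constants):

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t GroupSize    = 8; // hypothetical workgroup size
    const uint32_t SubgroupSize = 4; // hypothetical gl_SubgroupSize()

    for (uint32_t workGroupIndex = 0; workGroupIndex < 3; ++workGroupIndex)
    {
        printf("workgroup %u:", workGroupIndex);
        for (uint32_t tid = 0; tid < GroupSize; ++tid)
        {
            // same formula as the HLSL: rotate the lane by SubgroupSize * workGroupIndex,
            // wrapping within the group; (tid / GroupSize) * GroupSize keeps the block base
            const uint32_t shifted_tid = (tid + SubgroupSize * workGroupIndex) % GroupSize;
            const uint32_t bucket      = (tid / GroupSize) * GroupSize + shifted_tid;
            printf(" %u->%u", tid, bucket);
        }
        printf("\n"); // each row maps tids onto the same bucket set, shifted per workgroup
    }
    return 0;
}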
@@ -59,51 +71,50 @@ struct counting

    void histogram(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
    {
-        uint32_t tid = workgroup::SubgroupContiguousIndex();
-        uint32_t buckets_per_thread = (KeyBucketCount - 1) / GroupSize + 1;
-
        build_histogram(key, sdata, params);

+        uint32_t tid = workgroup::SubgroupContiguousIndex();
        uint32_t histogram_value = sdata.get(tid);

        sdata.workgroupExecutionAndMemoryBarrier();

        uint32_t sum = inclusive_scan(histogram_value, sdata);
-        histogram.atomicAdd(tid, sum);
+        toroidal_histogram_add(tid, sum, histogram, sdata, params);

        const bool is_last_wg_invocation = tid == (GroupSize - 1);
-        const uint16_t adjusted_key_bucket_count = KeyBucketCount + (GroupSize - KeyBucketCount % GroupSize);
+        const uint16_t adjusted_key_bucket_count = ((KeyBucketCount - 1) / GroupSize + 1) * GroupSize;

        for (tid += GroupSize; tid < adjusted_key_bucket_count; tid += GroupSize)
        {
-            if (is_last_wg_invocation) {
+            if (is_last_wg_invocation)
+            {
                uint32_t startIndex = tid - tid % GroupSize;
                sdata.set(startIndex, sdata.get(startIndex) + sum);
            }

            sum = inclusive_scan(sdata.get(tid), sdata);
-
-            histogram.atomicAdd(tid, sum);
+            toroidal_histogram_add(tid, sum, histogram, sdata, params);
        }
    }

    void scatter(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(ValueAccessor) val, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
    {
-        uint32_t tid = workgroup::SubgroupContiguousIndex();
-
        build_histogram(key, sdata, params);

-        for (; tid < KeyBucketCount; tid += GroupSize)
+        uint32_t tid = workgroup::SubgroupContiguousIndex();
+        uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
+
+        for (; shifted_tid < KeyBucketCount; shifted_tid += GroupSize)
        {
-            uint32_t bucket_value = sdata.get(tid);
-            uint32_t exclusive_value = histogram.atomicSub(tid, bucket_value) - bucket_value;
+            uint32_t bucket_value = sdata.get(shifted_tid);
+            uint32_t exclusive_value = histogram.atomicSub(shifted_tid, bucket_value) - bucket_value;

-            sdata.set(tid, exclusive_value);
+            sdata.set(shifted_tid, exclusive_value);
        }

        sdata.workgroupExecutionAndMemoryBarrier();

-        uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid % GroupSize;
+        uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid;

        uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);

        [unroll]
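
Note: scatter gets the same staggering. Each workgroup now walks the buckets starting from its own rotated offset (shifted_tid), so the atomicSub traffic from different workgroups is spread across buckets; the union of buckets visited by a workgroup is unchanged, only the invocation-to-bucket assignment rotates. A quick host-side check of that claim, under assumed illustrative sizes (not the library's defaults):

#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    const uint32_t GroupSize      = 8;  // hypothetical workgroup size
    const uint32_t SubgroupSize   = 4;  // hypothetical gl_SubgroupSize()
    const uint32_t KeyBucketCount = 40; // hypothetical bucket count

    for (uint32_t workGroupIndex = 0; workGroupIndex < 4; ++workGroupIndex)
    {
        std::vector<uint32_t> visits(KeyBucketCount, 0);
        for (uint32_t tid = 0; tid < GroupSize; ++tid)
        {
            uint32_t shifted_tid = (tid + SubgroupSize * workGroupIndex) % GroupSize;
            for (; shifted_tid < KeyBucketCount; shifted_tid += GroupSize)
                visits[shifted_tid]++; // stands in for the atomicSub on that bucket
        }
        for (uint32_t b = 0; b < KeyBucketCount; ++b)
            assert(visits[b] == 1); // every bucket handled exactly once per workgroup
    }
    return 0;
}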