@@ -18,7 +18,6 @@ namespace sort
 template<
     uint16_t GroupSize,
     uint16_t KeyBucketCount,
-    typename Key,
     typename KeyAccessor,
     typename ValueAccessor,
     typename HistogramAccessor,
@@ -27,35 +26,33 @@ template<
 >
 struct counting
 {
-    uint32_t inclusive_scan(uint32_t value, NBL_REF_ARG(SharedAccessor) sdata)
+    using key_t = decltype(impl::declval<KeyAccessor>().get(0));
+    using this_t = counting<GroupSize, KeyBucketCount, KeyAccessor, ValueAccessor, HistogramAccessor, SharedAccessor>;
+
+    static this_t create(const uint32_t workGroupIndex)
     {
-        return workgroup::inclusive_scan<plus<uint32_t>, GroupSize>::
-            template __call<SharedAccessor>(value, sdata);
+        this_t retval;
+        retval.workGroupIndex = workGroupIndex;
+        return retval;
     }

-    uint32_t toroidal_histogram_add(uint32_t tid, uint32_t sum, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
+    uint32_t inclusive_scan(uint32_t value, NBL_REF_ARG(SharedAccessor) sdata)
     {
-        sdata.workgroupExecutionAndMemoryBarrier();
-
-        sdata.set(tid % GroupSize, sum);
-        uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
-
-        sdata.workgroupExecutionAndMemoryBarrier();
-
-        return histogram.atomicAdd((tid / GroupSize) * GroupSize + shifted_tid, sdata.get(shifted_tid));
+        return workgroup::inclusive_scan<plus<uint32_t>, GroupSize>::
+            template __call<SharedAccessor>(value, sdata);
     }

-    void build_histogram(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
+    void build_histogram(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<key_t> params)
     {
         uint32_t tid = workgroup::SubgroupContiguousIndex();

-        for (; tid < KeyBucketCount; tid += GroupSize) {
-            sdata.set(tid, 0);
+        for (uint32_t vid = tid; vid < KeyBucketCount; vid += GroupSize) {
+            sdata.set(vid, 0);
         }

         sdata.workgroupExecutionAndMemoryBarrier();

-        uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid % GroupSize;
+        uint32_t index = workGroupIndex * GroupSize * params.elementsPerWT + tid;
         uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);

         for (; index < endIndex; index += GroupSize)
@@ -69,7 +66,12 @@ struct counting
         sdata.workgroupExecutionAndMemoryBarrier();
     }

-    void histogram(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
+    void histogram(
+        NBL_REF_ARG(KeyAccessor) key,
+        NBL_REF_ARG(HistogramAccessor) histogram,
+        NBL_REF_ARG(SharedAccessor) sdata,
+        const CountingParameters<key_t> params
+    )
     {
         build_histogram(key, sdata, params);

@@ -79,30 +81,36 @@ struct counting
         sdata.workgroupExecutionAndMemoryBarrier();

         uint32_t sum = inclusive_scan(histogram_value, sdata);
-        toroidal_histogram_add(tid, sum, histogram, sdata, params);
+        histogram.atomicAdd(tid, sum);

         const bool is_last_wg_invocation = tid == (GroupSize - 1);
         const uint16_t adjusted_key_bucket_count = ((KeyBucketCount - 1) / GroupSize + 1) * GroupSize;

-        for (tid += GroupSize; tid < adjusted_key_bucket_count; tid += GroupSize)
+        for (uint32_t vid = tid + GroupSize; vid < adjusted_key_bucket_count; vid += GroupSize)
         {
             if (is_last_wg_invocation)
             {
-                uint32_t startIndex = tid - tid % GroupSize;
+                uint32_t startIndex = vid - tid;
                 sdata.set(startIndex, sdata.get(startIndex) + sum);
             }

-            sum = inclusive_scan(sdata.get(tid), sdata);
-            toroidal_histogram_add(tid, sum, histogram, sdata, params);
+            sum = inclusive_scan(sdata.get(vid), sdata);
+            histogram.atomicAdd(vid, sum);
         }
     }

-    void scatter(NBL_REF_ARG(KeyAccessor) key, NBL_REF_ARG(ValueAccessor) val, NBL_REF_ARG(HistogramAccessor) histogram, NBL_REF_ARG(SharedAccessor) sdata, const CountingParameters<Key> params)
+    void scatter(
+        NBL_REF_ARG(KeyAccessor) key,
+        NBL_REF_ARG(ValueAccessor) val,
+        NBL_REF_ARG(HistogramAccessor) histogram,
+        NBL_REF_ARG(SharedAccessor) sdata,
+        const CountingParameters<key_t> params
+    )
     {
         build_histogram(key, sdata, params);

         uint32_t tid = workgroup::SubgroupContiguousIndex();
-        uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * params.workGroupIndex) % GroupSize;
+        uint32_t shifted_tid = (tid + glsl::gl_SubgroupSize() * workGroupIndex) % GroupSize;

         for (; shifted_tid < KeyBucketCount; shifted_tid += GroupSize)
         {
@@ -114,13 +122,13 @@ struct counting

         sdata.workgroupExecutionAndMemoryBarrier();

-        uint32_t index = params.workGroupIndex * GroupSize * params.elementsPerWT + tid;
+        uint32_t index = workGroupIndex * GroupSize * params.elementsPerWT + tid;
         uint32_t endIndex = min(params.dataElementCount, index + GroupSize * params.elementsPerWT);

         [unroll]
         for (; index < endIndex; index += GroupSize)
         {
-            const Key k = key.get(index);
+            const key_t k = key.get(index);
             if (robust && (k < params.minimum || k > params.maximum))
                 continue;
             const uint32_t v = val.get(index);
@@ -129,6 +137,8 @@ struct counting
             val.set(sortedIx, v);
         }
     }
+
+    uint32_t workGroupIndex;
 };

 }
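For context, a minimal usage sketch of the interface after this change. Only the counting<...> members (create, histogram, scatter, key_t, workGroupIndex) and the CountingParameters fields referenced in the diff come from the source; the accessor types, push-constant layout, constants, entry-point wiring, and the enclosing nbl::hlsl namespace path are assumptions for illustration.

// Hypothetical wiring of the refactored counting-sort pass; MyKeyAccessor,
// MyValueAccessor, MyHistogramAccessor, MySharedAccessor, WorkgroupSize,
// BucketCount and the 'pc' push constants are assumed, not part of this diff.
using counter_t = nbl::hlsl::sort::counting<
    WorkgroupSize, BucketCount,
    MyKeyAccessor, MyValueAccessor, MyHistogramAccessor, MySharedAccessor>;

[numthreads(WorkgroupSize, 1, 1)]
void main(uint3 groupID : SV_GroupID)
{
    // the workgroup index now lives on the struct (set via create)
    // instead of being carried in CountingParameters
    counter_t counter = counter_t::create(groupID.x);

    MyKeyAccessor keys;
    MyValueAccessor values;
    MyHistogramAccessor histogram;
    MySharedAccessor sharedMem;

    // the key type is no longer a template argument; it is deduced inside the
    // struct as key_t from what KeyAccessor::get() returns (uint32_t assumed here)
    nbl::hlsl::sort::CountingParameters<uint32_t> params;
    params.dataElementCount = pc.dataElementCount;
    params.elementsPerWT = pc.elementsPerWT;
    params.minimum = pc.minimum;
    params.maximum = pc.maximum;

    // histogram pass; a later dispatch would call
    // counter.scatter(keys, values, histogram, sharedMem, params)
    counter.histogram(keys, histogram, sharedMem, params);
}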