Skip to content

Commit ccacddb

Browse files
committed
store temporaries with data accessor
1 parent 004c95a commit ccacddb

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,21 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
151151
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
152152
BinOp binop;
153153

154-
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
155154
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
156155
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
157156
// level 0 scan
158157
[unroll]
159158
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
160159
{
161-
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
162-
scan_local[idx] = inclusiveScan0(scan_local[idx]);
160+
vector_lv0_t value;
161+
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
162+
value = inclusiveScan0(value);
163+
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
163164
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
164165
{
165166
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
166167
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
167-
scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
168+
scratchAccessor.template set<scalar_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
168169
}
169170
}
170171
scratchAccessor.workgroupExecutionAndMemoryBarrier();
@@ -188,23 +189,26 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
188189
[unroll]
189190
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
190191
{
192+
vector_lv0_t value;
193+
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
194+
191195
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
192196
scalar_t left;
193197
scratchAccessor.template get<scalar_t>(virtualSubgroupID,left);
194198
if (Exclusive)
195199
{
196-
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
200+
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
197201
[unroll]
198202
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
199-
scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
203+
value[Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(value[Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
200204
}
201205
else
202206
{
203207
[unroll]
204208
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
205-
scan_local[idx][i] = binop(left, scan_local[idx][i]);
209+
value[i] = binop(left, value[i]);
206210
}
207-
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
211+
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
208212
}
209213
}
210214
};
@@ -303,20 +307,21 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
303307
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
304308
BinOp binop;
305309

306-
vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
307310
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
308311
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
309312
// level 0 scan
310313
[unroll]
311314
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
312315
{
313-
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
314-
scan_local[idx] = inclusiveScan0(scan_local[idx]);
316+
vector_lv0_t value;
317+
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
318+
value = inclusiveScan0(value);
319+
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
315320
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
316321
{
317322
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
318323
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
319-
scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[idx][Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
324+
scratchAccessor.template set<scalar_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
320325
}
321326
}
322327
scratchAccessor.workgroupExecutionAndMemoryBarrier();
@@ -368,23 +373,26 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
368373
[unroll]
369374
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
370375
{
376+
vector_lv0_t value;
377+
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
378+
371379
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
372380
const scalar_t left;
373381
scratchAccessor.template get<scalar_t>(virtualSubgroupID, left);
374382
if (Exclusive)
375383
{
376-
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx][Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
384+
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
377385
[unroll]
378386
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
379-
scan_local[idx][Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(scan_local[idx][Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
387+
value[Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(value[Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
380388
}
381389
else
382390
{
383391
[unroll]
384392
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
385-
scan_local[idx][i] = binop(left, scan_local[idx][i]);
393+
value[i] = binop(left, value[i]);
386394
}
387-
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
395+
dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
388396
}
389397
}
390398
};

0 commit comments

Comments
 (0)