Skip to content

Commit b062ede

Browse files
committed
simplified indexing functions
1 parent d758ff7 commit b062ede

File tree

2 files changed

+27
-21
lines changed

2 files changed

+27
-21
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,20 @@ struct ArithmeticConfiguration
7171
return workgroupInVirtualIndex * (WorkgroupSize >> SubgroupSizeLog2) + subgroupID;
7272
}
7373

74-
static uint32_t sharedStoreIndex(const uint32_t subgroupID, const uint32_t itemsPerInvocation)
74+
template<uint16_t level>
75+
static uint32_t sharedStoreIndex(const uint32_t subgroupID)
7576
{
76-
return (subgroupID & (itemsPerInvocation-1)) * SubgroupSize + (subgroupID/itemsPerInvocation);
77+
if (level<2)
78+
return (subgroupID & (ItemsPerInvocation_1-1)) * SubgroupSize + (subgroupID/ItemsPerInvocation_1);
79+
else
80+
return (subgroupID & (ItemsPerInvocation_2-1)) * SubgroupSize + (subgroupID/ItemsPerInvocation_2);
81+
}
82+
83+
template<uint16_t level>
84+
static uint32_t sharedStoreIndexFromVirtualIndex(const uint32_t subgroupID, const uint32_t workgroupInVirtualIndex)
85+
{
86+
const uint32_t virtualID = virtualSubgroupID(subgroupID, workgroupInVirtualIndex);
87+
return sharedStoreIndex<level>(virtualID);
7788
}
7889

7990
static uint32_t sharedLoadIndex(const uint32_t invocationIndex, const uint32_t component)

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
104104
scan_local = reduction0(scan_local);
105105
if (Config::electLast())
106106
{
107-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
108-
const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
107+
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
109108
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
110109
}
111110
}
@@ -159,8 +158,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
159158
dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
160159
if (Config::electLast())
161160
{
162-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
163-
const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
161+
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
164162
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
165163
}
166164
}
@@ -174,7 +172,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
174172
const uint32_t prevIndex = invocationIndex-1;
175173
[unroll]
176174
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
177-
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(prevIndex, i),lv1_val[i]);
175+
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
178176
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
179177
lv1_val = inclusiveScan1(lv1_val);
180178
[unroll]
@@ -190,8 +188,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
190188
vector_lv0_t value;
191189
dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
192190

193-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
194-
const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
191+
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
195192
scalar_t left;
196193
scratchAccessor.template get<scalar_t, uint32_t>(bankedIndex,left);
197194
if (Exclusive)
@@ -242,8 +239,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
242239
scan_local = reduction0(scan_local);
243240
if (Config::electLast())
244241
{
245-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
246-
const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1); // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
242+
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
247243
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
248244
}
249245
}
@@ -261,7 +257,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
261257
lv1_val = reduction1(lv1_val);
262258
if (Config::electLast())
263259
{
264-
const uint32_t bankedIndex = Config::sharedStoreIndex(invocationIndex, Config::ItemsPerInvocation_2); // (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
260+
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(invocationIndex);
265261
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
266262
}
267263
}
@@ -276,7 +272,8 @@ struct reduce<Config, BinOp, 3, device_capabilities>
276272
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
277273
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i),lv2_val[i]);
278274
lv2_val = reduction2(lv2_val);
279-
scratchAccessor.template set<scalar_t, uint32_t>(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]);
275+
if (Config::electLast())
276+
scratchAccessor.template set<scalar_t, uint32_t>(0, lv2_val[Config::ItemsPerInvocation_2-1]);
280277
}
281278
scratchAccessor.workgroupExecutionAndMemoryBarrier();
282279

@@ -315,8 +312,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
315312
dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
316313
if (Config::electLast())
317314
{
318-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
319-
const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
315+
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
320316
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
321317
}
322318
}
@@ -331,15 +327,15 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
331327
const uint32_t prevIndex = invocationIndex-1;
332328
[unroll]
333329
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
334-
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(prevIndex, i),lv1_val[i]);
330+
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
335331
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
336332
lv1_val = inclusiveScan1(lv1_val);
337333
[unroll]
338334
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
339335
scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
340336
if (Config::electLast())
341337
{
342-
const uint32_t bankedIndex = Config::sharedStoreIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
338+
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
343339
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
344340
}
345341
}
@@ -353,7 +349,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
353349
const uint32_t prevIndex = invocationIndex-1;
354350
[unroll]
355351
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
356-
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(prevIndex, i),lv2_val[i]);
352+
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i)-1,lv2_val[i]);
357353
lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
358354
lv2_val = inclusiveScan2(lv2_val);
359355
[unroll]
@@ -371,7 +367,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
371367
scratchAccessor.template get<scalar_t, uint32_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
372368

373369
scalar_t lv2_scan;
374-
const uint32_t bankedIndex = Config::sharedStoreIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
370+
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
375371
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv2_scan);
376372

377373
[unroll]
@@ -386,8 +382,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
386382
vector_lv0_t value;
387383
dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
388384

389-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
390-
const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
385+
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
391386
scalar_t left;
392387
scratchAccessor.template get<scalar_t, uint32_t>(bankedIndex,left);
393388
if (Exclusive)

0 commit comments

Comments
 (0)