Skip to content

Commit 127c6d9

Browse files
committed
some fixes to indexing
1 parent 951ff99 commit 127c6d9

File tree

3 files changed

+17
-16
lines changed

3 files changed

+17
-16
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,13 @@ struct ArithmeticConfiguration
101101
return sharedStoreIndex<level>(virtualID);
102102
}
103103

104+
template<uint16_t level>
104105
static uint32_t sharedLoadIndex(const uint32_t invocationIndex, const uint32_t component)
105106
{
106-
return component * SubgroupSize + invocationIndex;
107+
if (level == LevelCount-1)
108+
return component * SubgroupSize + invocationIndex;
109+
else
110+
return component * __SubgroupsPerVirtualWorkgroup + invocationIndex;
107111
}
108112
};
109113

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
124124
vector_lv1_t lv1_val;
125125
[unroll]
126126
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
127-
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
127+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
128128
lv1_val = reduction1(lv1_val);
129129

130130
if (Config::electLast())
@@ -183,15 +183,14 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
183183
if (glsl::gl_SubgroupID() == 0)
184184
{
185185
vector_lv1_t lv1_val;
186-
const uint32_t prevIndex = invocationIndex-1;
187186
[unroll]
188187
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
189-
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
188+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i)-1,lv1_val[i]);
190189
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
191190
lv1_val = inclusiveScan1(lv1_val);
192191
[unroll]
193192
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
194-
scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
193+
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
195194
}
196195
scratchAccessor.workgroupExecutionAndMemoryBarrier();
197196

@@ -253,11 +252,11 @@ struct reduce<Config, BinOp, 3, device_capabilities>
253252
vector_lv1_t lv1_val;
254253
[unroll]
255254
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
256-
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
255+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
257256
lv1_val = reduction1(lv1_val);
258257
if (Config::electLast())
259258
{
260-
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(invocationIndex);
259+
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
261260
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
262261
}
263262
}
@@ -270,7 +269,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
270269
vector_lv2_t lv2_val;
271270
[unroll]
272271
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
273-
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i),lv2_val[i]);
272+
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
274273
lv2_val = reduction2(lv2_val);
275274
if (Config::electLast())
276275
scratchAccessor.template set<scalar_t, uint32_t>(0, lv2_val[Config::ItemsPerInvocation_2-1]);
@@ -309,15 +308,14 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
309308
if (glsl::gl_SubgroupID() < Config::SubgroupsSize*Config::ItemsPerInvocation_2)
310309
{
311310
vector_lv1_t lv1_val;
312-
const uint32_t prevIndex = invocationIndex-1;
313311
[unroll]
314312
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
315-
scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
313+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i)-1,lv1_val[i]);
316314
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
317315
lv1_val = inclusiveScan1(lv1_val);
318316
[unroll]
319317
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
320-
scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
318+
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
321319
if (Config::electLast())
322320
{
323321
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
@@ -331,15 +329,14 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
331329
if (glsl::gl_SubgroupID() == 0)
332330
{
333331
vector_lv2_t lv2_val;
334-
const uint32_t prevIndex = invocationIndex-1;
335332
[unroll]
336333
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
337-
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i)-1,lv2_val[i]);
334+
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i)-1,lv2_val[i]);
338335
lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
339336
lv2_val = inclusiveScan2(lv2_val);
340337
[unroll]
341338
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
342-
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i),lv2_val[i]);
339+
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
343340
}
344341
scratchAccessor.workgroupExecutionAndMemoryBarrier();
345342

@@ -357,7 +354,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
357354

358355
[unroll]
359356
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
360-
scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i), binop(lv1_val[i],lv2_scan));
357+
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i],lv2_scan));
361358
}
362359

363360
// combine with level 0

0 commit comments

Comments
 (0)