@@ -124,7 +124,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
         vector_lv1_t lv1_val;
         [unroll]
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-            scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
+            scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
         lv1_val = reduction1(lv1_val);
 
         if (Config::electLast())
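The recurring change across these hunks is `Config::sharedLoadIndex(...)` becoming `Config::template sharedLoadIndex<1>(...)`: the index helper is now a member template parameterized on the scan level. Since `Config` is a dependent type inside these struct templates, the call site also needs the `template` disambiguator. A minimal C++ sketch of the syntax (DXC's HLSL templates behave the same way; `ExampleConfig` and its members are hypothetical stand-ins, not Nabla's real definitions):

```cpp
// Why the `template` keyword is needed on a dependent member template call.
#include <cstdint>

template<uint32_t ItemsPerInvocation>
struct ExampleConfig
{
    // Index of invocation ix's i-th item in the level-`level` scratch region
    // (real banking logic omitted; `level` would select the region's stride).
    template<uint32_t level>
    static uint32_t sharedLoadIndex(const uint32_t ix, const uint32_t i)
    {
        return ix*ItemsPerInvocation + i;
    }
};

template<class Config>
struct User
{
    static uint32_t load(const uint32_t ix, const uint32_t i)
    {
        // Config is a dependent type here, so without `template` the `<`
        // would parse as less-than and the call would not compile:
        //   return Config::sharedLoadIndex<1>(ix, i);      // ill-formed
        return Config::template sharedLoadIndex<1>(ix, i);  // OK
    }
};

int main()
{
    return (int)User<ExampleConfig<4> >::load(2, 3); // 2*4 + 3 == 11
}
```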
@@ -183,15 +183,14 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
         if (glsl::gl_SubgroupID() == 0)
         {
             vector_lv1_t lv1_val;
-            const uint32_t prevIndex = invocationIndex-1;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
+                scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i)-1,lv1_val[i]);
             lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
             lv1_val = inclusiveScan1(lv1_val);
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
+                scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
         }
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
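The deleted `const uint32_t prevIndex = invocationIndex-1;` was dead code: the shift by one lives in the load itself (`sharedLoadIndex<1>(invocationIndex, i)-1`), and invocation 0, which has no left neighbour, gets its first element replaced by `BinOp::identity` via `hlsl::mix`. An inclusive scan over those shifted values then yields the exclusive scan. A scalar C++ sketch of the same trick, with a plain array standing in for the scratch accessor (names and data are illustrative):

```cpp
// Exclusive scan obtained by shifting loads left by one, seeding element 0
// with the identity, then running an inclusive scan.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t scratch[4] = {3, 1, 4, 1}; // level-1 partial sums
    uint32_t val[4];
    for (uint32_t ix = 0; ix < 4; ix++)
        val[ix] = ix ? scratch[ix - 1] : 0u;  // load at index-1; identity at ix==0
    for (uint32_t ix = 1; ix < 4; ix++)       // inclusive scan over shifted values
        val[ix] += val[ix - 1];
    for (uint32_t ix = 0; ix < 4; ix++)
        printf("%u ", val[ix]);               // prints: 0 3 4 8, the exclusive scan
    return 0;
}
```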
@@ -253,11 +252,11 @@ struct reduce<Config, BinOp, 3, device_capabilities>
             vector_lv1_t lv1_val;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
+                scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
             lv1_val = reduction1(lv1_val);
             if (Config::electLast())
             {
-                const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(invocationIndex);
+                const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
                 scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
             }
         }
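Besides the `sharedLoadIndex` template change, this hunk carries a substantive fix: the banked slot for a subgroup's level-1 partial is now keyed by `glsl::gl_SubgroupID()` instead of `invocationIndex`. Only the elected last invocation of each subgroup stores, so the slot must identify the subgroup, matching how level 2 reads the partials back. A toy C++ sketch (the constants and the `sharedStoreIndex2` stand-in are illustrative, not Nabla's actual banking):

```cpp
// The elected (last) invocation of each subgroup stores that subgroup's
// partial, so the slot must be keyed by subgroup ID, not invocation index.
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for Config::template sharedStoreIndex<2>:
uint32_t sharedStoreIndex2(const uint32_t virtualID) { return virtualID; }

int main()
{
    const uint32_t SubgroupSize = 32;
    const uint32_t invocationIndex = 95;  // last invocation of subgroup 2
    const uint32_t subgroupID = invocationIndex/SubgroupSize;
    // Old: keyed by invocationIndex, landing far outside the small region
    // of per-subgroup partials. New: one slot per subgroup.
    printf("old slot %u vs. fixed slot %u\n",
        sharedStoreIndex2(invocationIndex),  // 95
        sharedStoreIndex2(subgroupID));      // 2
    return 0;
}
```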
@@ -270,7 +269,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
             vector_lv2_t lv2_val;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-                scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i),lv2_val[i]);
+                scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
             lv2_val = reduction2(lv2_val);
             if (Config::electLast())
                 scratchAccessor.template set<scalar_t, uint32_t>(0, lv2_val[Config::ItemsPerInvocation_2-1]);
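The offsets in this hunk imply the scratch layout: level-1 partials occupy the first `lv1_smem_size` slots, level-2 values live right after them, and the elected invocation writes the grand total back to slot 0. A C++ sketch of the address arithmetic (the sizes and the flat per-level indexing are assumptions, not Nabla's actual banking):

```cpp
// Scratch layout implied by the offsets above: [level-1 region | level-2
// region], with the final total reusing slot 0. All values illustrative.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t lv1_smem_size = 64;       // slots holding level-1 partials
    const uint32_t subgroupID = 3, i = 1;    // a level-2 read by subgroup 3
    const uint32_t ItemsPerInvocation_2 = 2;
    // Hypothetical stand-in for Config::template sharedLoadIndex<2>(ix, i):
    const uint32_t loadIndex2 = subgroupID*ItemsPerInvocation_2 + i;
    printf("level-2 value read from slot %u\n", lv1_smem_size + loadIndex2); // 71
    printf("grand total written to slot %u\n", 0u); // reused level-1 slot 0
    return 0;
}
```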
@@ -309,15 +308,14 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         if (glsl::gl_SubgroupID() < Config::SubgroupsSize*Config::ItemsPerInvocation_2)
         {
             vector_lv1_t lv1_val;
-            const uint32_t prevIndex = invocationIndex-1;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
+                scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i)-1,lv1_val[i]);
             lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
             lv1_val = inclusiveScan1(lv1_val);
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
+                scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
             if (Config::electLast())
             {
                 const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
@@ -331,15 +329,14 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         if (glsl::gl_SubgroupID() == 0)
         {
             vector_lv2_t lv2_val;
-            const uint32_t prevIndex = invocationIndex-1;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-                scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i)-1,lv2_val[i]);
+                scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i)-1,lv2_val[i]);
             lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
             lv2_val = inclusiveScan2(lv2_val);
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-                scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i),lv2_val[i]);
+                scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
         }
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
@@ -357,7 +354,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
 
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i), binop(lv1_val[i],lv2_scan));
+                scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i],lv2_scan));
         }
 
         // combine with level 0
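This last hunk is the downsweep: each invocation folds the exclusive level-2 prefix of its block (`lv2_scan`) into every level-1 element it owns via `binop`, before the final combine with level 0. A scalar C++ sketch with addition as `binop` (arrays and sizes are illustrative):

```cpp
// Downsweep combine: add each block's exclusive level-2 prefix to all of
// that block's level-1 exclusive-scan values to get global prefixes.
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t lv1[2][4] = {{0,3,4,8},{0,2,7,7}}; // per-block exclusive scans
    const uint32_t blockTotal[2] = {9, 9};      // totals of blocks 0 and 1
    uint32_t lv2_scan = 0;                      // exclusive scan of totals
    for (uint32_t b = 0; b < 2; b++)
    {
        for (uint32_t i = 0; i < 4; i++)
            lv1[b][i] += lv2_scan;              // binop(lv1_val[i], lv2_scan)
        lv2_scan += blockTotal[b];
    }
    printf("%u\n", lv1[1][2]);                  // 7 + 9 = 16: a global prefix
    return 0;
}
```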