@@ -104,8 +104,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
         scan_local = reduction0(scan_local);
         if (Config::electLast())
         {
-            const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-            const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);    // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
+            const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
             scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]);    // set last element of subgroup scan (reduction) to level 1 scan
         }
     }
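Here and in the matching hunks below, the two-step lookup (compute virtualSubgroupID, then bank it through sharedStoreIndex) collapses into a single level-templated helper. The banked formula survives in the deleted inline comments, so the new helpers presumably reduce to something like this sketch; Items, Stride, and the helper name are stand-ins, not Nabla's actual Config implementation:

    // Hedged reconstruction from the deleted comments, not the real Config code.
    // Per those comments, level 1 pairs ItemsPerInvocation_1 with
    // SubgroupsPerVirtualWorkgroup and level 2 pairs ItemsPerInvocation_2
    // with SubgroupSize.
    template<uint32_t Items, uint32_t Stride>
    uint32_t bankedStoreIndex(const uint32_t virtualID)
    {
        // Items must be a power of two for the mask to act as a modulus
        return (virtualID & (Items-1)) * Stride + virtualID/Items;
    }

sharedStoreIndexFromVirtualIndex<1> then presumably just composes Config::virtualSubgroupID(subgroupID, idx) with the level-1 banking, i.e. the two deleted lines folded into one call.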
@@ -159,8 +158,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
         dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
         if (Config::electLast())
         {
-            const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-            const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);    // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
+            const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
             scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]);    // set last element of subgroup scan (reduction) to level 1 scan
         }
     }
@@ -174,7 +172,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
         const uint32_t prevIndex = invocationIndex-1;
         [unroll]
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-            scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(prevIndex, i),lv1_val[i]);
+            scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
         lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
         lv1_val = inclusiveScan1(lv1_val);
         [unroll]
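Note that this hunk (and its twins in the -331 and -353 hunks below) is a behavioral change, not a cleanup: mapping prevIndex through a banked index function is not the same as mapping invocationIndex and stepping back one flat slot, because banking is nonlinear in the invocation argument. Whether that matters here depends on Config::sharedLoadIndex, which this diff doesn't show; the arithmetic below only illustrates that the two forms can disagree, using the banked store formula from the old comments with hypothetical Items = 2, Stride = 16:

    // Illustrative only: argument-side vs. result-side offsets differ under banking.
    uint32_t banked(const uint32_t v) { return (v & 1u)*16u + v/2u; }
    // banked(4)    == 0*16 + 2     == 2   (old form: bank prevIndex == 4)
    // banked(5)-1u == 1*16 + 2 - 1 == 17  (new form: bank invocationIndex == 5, step back one slot)

There is also the edge case: for invocationIndex == 0, prevIndex underflows and, under a banked layout, poisons every component's address in the old form; in the new form only component 0 can go out of range, and lv1_val[0] is immediately overwritten by the hlsl::mix with BinOp::identity on the next line.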
@@ -190,8 +188,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
         vector_lv0_t value;
         dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
 
-        const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-        const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+        const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
         scalar_t left;
         scratchAccessor.template get<scalar_t, uint32_t>(bankedIndex,left);
         if (Exclusive)
@@ -242,8 +239,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
         scan_local = reduction0(scan_local);
         if (Config::electLast())
         {
-            const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-            const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);    // (virtualSubgroupID & (Config::ItemsPerInvocation_1-1)) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
+            const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
             scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]);    // set last element of subgroup scan (reduction) to level 1 scan
         }
     }
@@ -261,7 +257,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
         lv1_val = reduction1(lv1_val);
         if (Config::electLast())
         {
-            const uint32_t bankedIndex = Config::sharedStoreIndex(invocationIndex, Config::ItemsPerInvocation_2);    // (invocationIndex & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupsPerVirtualWorkgroup + (invocationIndex/Config::ItemsPerInvocation_2);
+            const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(invocationIndex);
             scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
         }
     }
@@ -276,7 +272,8 @@ struct reduce<Config, BinOp, 3, device_capabilities>
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
             scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i),lv2_val[i]);
         lv2_val = reduction2(lv2_val);
-        scratchAccessor.template set<scalar_t, uint32_t>(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]);
+        if (Config::electLast())
+            scratchAccessor.template set<scalar_t, uint32_t>(0, lv2_val[Config::ItemsPerInvocation_2-1]);
     }
     scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
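This hunk also changes behavior, not just indexing: previously every active invocation wrote lv2_val's last component to its own invocationIndex slot; now a single elected lane publishes the final reduction to slot 0. The read side sits outside this diff, but the shape is the usual elect-and-broadcast through shared memory, sketched here with assumed accessor semantics (total and reduced are placeholder names):

    // Assumed pattern (read side not shown in this diff): one lane writes the
    // workgroup total to a fixed slot, everyone reads it back after the barrier.
    if (Config::electLast())
        scratchAccessor.template set<scalar_t, uint32_t>(0, total);
    scratchAccessor.workgroupExecutionAndMemoryBarrier();
    scalar_t reduced;
    scratchAccessor.template get<scalar_t, uint32_t>(0, reduced);

One write followed by one broadcast read avoids both the redundant stores and any ambiguity about which slot holds the result.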
@@ -315,8 +312,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
         if (Config::electLast())
         {
-            const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-            const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+            const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
             scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]);    // set last element of subgroup scan (reduction) to level 1 scan
         }
     }
@@ -331,15 +327,15 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         const uint32_t prevIndex = invocationIndex-1;
         [unroll]
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-            scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(prevIndex, i),lv1_val[i]);
+            scratchAccessor.template get<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i)-1,lv1_val[i]);
         lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
         lv1_val = inclusiveScan1(lv1_val);
         [unroll]
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
             scratchAccessor.template set<scalar_t, uint32_t>(Config::sharedLoadIndex(invocationIndex, i),lv1_val[i]);
         if (Config::electLast())
         {
-            const uint32_t bankedIndex = Config::sharedStoreIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2);    // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
+            const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
             scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
         }
     }
@@ -353,7 +349,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         const uint32_t prevIndex = invocationIndex-1;
         [unroll]
         for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
-            scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(prevIndex, i),lv2_val[i]);
+            scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::sharedLoadIndex(invocationIndex, i)-1,lv2_val[i]);
         lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
         lv2_val = inclusiveScan2(lv2_val);
         [unroll]
@@ -371,7 +367,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
             scratchAccessor.template get<scalar_t, uint32_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
 
         scalar_t lv2_scan;
-        const uint32_t bankedIndex = Config::sharedStoreIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2);    // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
+        const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
         scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv2_scan);
 
         [unroll]
@@ -386,8 +382,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         vector_lv0_t value;
         dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
 
-        const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
-        const uint32_t bankedIndex = Config::sharedStoreIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+        const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
         scalar_t left;
         scratchAccessor.template get<scalar_t, uint32_t>(bankedIndex,left);
         if (Exclusive)