@@ -179,15 +179,15 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
179
179
180
180
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
181
181
// level 1 scan
182
- subgroup2::inclusive_scan <params_lv1_t> inclusiveScan1 ;
182
+ subgroup2::exclusive_scan <params_lv1_t> exclusiveScan1 ;
183
183
if (glsl::gl_SubgroupID () == 0 )
184
184
{
185
185
vector_lv1_t lv1_val;
186
186
[unroll]
187
187
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
188
- scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i)- 1 ,lv1_val[i]);
189
- lv1_val[0 ] = hlsl::mix (BinOp::identity, lv1_val[0 ], bool (invocationIndex));
190
- lv1_val = inclusiveScan1 (lv1_val);
188
+ scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i),lv1_val[i]);
189
+ // lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
190
+ lv1_val = exclusiveScan1 (lv1_val);
191
191
[unroll]
192
192
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
193
193
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i),lv1_val[i]);
@@ -304,15 +304,16 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
304
304
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
305
305
// level 1 scan
306
306
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
307
- subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
308
- if (glsl::gl_SubgroupID () < Config::SubgroupsSize*Config::ItemsPerInvocation_2)
307
+ const uint32_t lv1_num_invoc = Config::SubgroupsSize*Config::ItemsPerInvocation_2;
308
+ subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
309
+ if (glsl::gl_SubgroupID () < lv1_num_invoc)
309
310
{
310
311
vector_lv1_t lv1_val;
311
312
[unroll]
312
313
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
313
- scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i)- 1 ,lv1_val[i]);
314
- lv1_val[0 ] = hlsl::mix (BinOp::identity, lv1_val[0 ], bool (invocationIndex));
315
- lv1_val = inclusiveScan1 (lv1_val);
314
+ scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i),lv1_val[i]);
315
+ // lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
316
+ lv1_val = exclusiveScan1 (lv1_val);
316
317
[unroll]
317
318
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_1; i++)
318
319
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1 >(invocationIndex, i),lv1_val[i]);
@@ -325,23 +326,23 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
325
326
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
326
327
327
328
// level 2 scan
328
- subgroup2::inclusive_scan <params_lv2_t> inclusiveScan2 ;
329
+ subgroup2::exclusive_scan <params_lv2_t> exclusiveScan2 ;
329
330
if (glsl::gl_SubgroupID () == 0 )
330
331
{
331
332
vector_lv2_t lv2_val;
332
333
[unroll]
333
334
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_2; i++)
334
- scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2 >(invocationIndex, i)- 1 ,lv2_val[i]);
335
+ scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2 >(invocationIndex, i),lv2_val[i]);
335
336
lv2_val[0 ] = hlsl::mix (BinOp::identity, lv2_val[0 ], bool (invocationIndex));
336
- lv2_val = inclusiveScan2 (lv2_val);
337
+ lv2_val = exclusiveScan2 (lv2_val);
337
338
[unroll]
338
339
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_2; i++)
339
340
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2 >(invocationIndex, i),lv2_val[i]);
340
341
}
341
342
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
342
343
343
344
// combine with level 1
344
- if (glsl::gl_SubgroupID () < lv1_smem_size )
345
+ if (glsl::gl_SubgroupID () < lv1_num_invoc )
345
346
{
346
347
vector_lv1_t lv1_val;
347
348
[unroll]
0 commit comments