@@ -120,7 +120,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
vector_lv1_t lv1_val;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
- scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
+ scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
lv1_val = reduction1(lv1_val);

if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
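
The only change in this hunk (and in the matching hunks below) is the scratch stride: element i of an invocation is now stored at i*Config::SubgroupSize+invocationIndex instead of i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex, so consecutive lanes of a subgroup always touch consecutive scratch slots. A minimal sketch of the assumed addressing; the helper name and the explicit subgroupSize parameter are illustrative, not part of the PR:

    // Hypothetical helper, for illustration only: scratch slot of element `i` held by `invocationIndex`.
    // New layout: stride by the subgroup size, so lane n and lane n+1 hit adjacent slots (coalesced).
    uint32_t coalescedScratchIndex(uint32_t i, uint32_t invocationIndex, uint32_t subgroupSize)
    {
        return i * subgroupSize + invocationIndex; // previously: i * subgroupsPerVirtualWorkgroup + invocationIndex
    }
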
@@ -176,12 +176,12 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
const uint32_t prevIndex = invocationIndex-1;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
- scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
+ scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+prevIndex,lv1_val[i]);
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
lv1_val = inclusiveScan1(lv1_val);
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
- scratchAccessor.template set<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
+ scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
}
scratchAccessor.workgroupExecutionAndMemoryBarrier();
@@ -258,7 +258,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
vector_lv1_t lv1_val;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
- scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
+ scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
lv1_val = reduction1(lv1_val);
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
{
@@ -275,7 +275,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
vector_lv2_t lv2_val;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
- scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv2_val[i]);
+ scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv2_val[i]);
lv2_val = reduction2(lv2_val);
scratchAccessor.template set<scalar_t>(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]);
}
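
In the 3-level hunks that follow, scratch is treated as two consecutive regions: the first Config::SubgroupSize*Config::ItemsPerInvocation_1 slots (lv1_smem_size) hold the level-1 partials, and level-2 values live after that at lv1_smem_size+i*Config::SubgroupSize+invocationIndex. A small sketch of that layout under assumed sizes (SubgroupSize = 32, ItemsPerInvocation_1 = 2); the function is illustrative, not part of the PR:

    // Illustration only: where the 3-level path below places its partials in shared scratch.
    // With subgroupSize = 32 and itemsPerInvocation_1 = 2:
    //   level-1 region: slots [0, 64),   element i of lane n at i*32 + n
    //   level-2 region: slots [64, ...), element i of lane n at 64 + i*32 + n
    uint32_t lv2ScratchSlot(uint32_t i, uint32_t lane, uint32_t subgroupSize, uint32_t itemsPerInvocation_1)
    {
        const uint32_t lv1_smem_size = subgroupSize * itemsPerInvocation_1;
        return lv1_smem_size + i * subgroupSize + lane;
    }
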
@@ -324,15 +324,20 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
scratchAccessor.workgroupExecutionAndMemoryBarrier();

// level 1 scan
- const uint32_t lv1_smem_size = Config::SubgroupsPerVirtualWorkgroup*Config::ItemsPerInvocation_1;
+ const uint32_t lv1_smem_size = Config::SubgroupSize*Config::ItemsPerInvocation_1;
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
if (glsl::gl_SubgroupID() < lv1_smem_size)
{
vector_lv1_t lv1_val;
+ const uint32_t prevIndex = invocationIndex-1;
[unroll]
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
- scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
+ scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+prevIndex,lv1_val[i]);
+ lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
lv1_val = inclusiveScan1(lv1_val);
+ [unroll]
+ for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
+ scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
{
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
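
The rewritten level-1 pass above uses the standard trick of computing an exclusive scan with an inclusive-scan primitive: each invocation loads its left neighbour's value (prevIndex), lane 0 substitutes BinOp::identity via hlsl::mix, and the inclusive scan of the shifted sequence equals the exclusive scan of the original one. A worked example, assuming BinOp is addition (identity 0):

    // Illustration only (addition, identity = 0):
    //   original values          : 3 1 4 1
    //   shifted load + identity  : 0 3 1 4   (lane 0 takes the identity, lane n takes lane n-1's value)
    //   inclusive scan of shifted: 0 3 4 8   == exclusive scan of the original values

The bankedIndex written at the end of the block follows the formula in its inline comment; for instance, with SubgroupSize = 32 and ItemsPerInvocation_2 = 2, subgroup 5 lands at (5 & 1)*32 + 5/2 = 34.
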
@@ -351,37 +356,48 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
scratchAccessor.template get<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]);
lv2_val[0] = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val[0], bool(invocationIndex));
- vector_lv2_t shiftedScan = inclusiveScan2(lv2_val);
-
- // combine with level 1, only last element of each
+ lv2_val = inclusiveScan2(lv2_val);
[unroll]
- for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++)
- {
- scalar_t last_val;
- scratchAccessor.template get<scalar_t>((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val);
- scalar_t val = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(i));
- val = binop(last_val, shiftedScan[Config::ItemsPerInvocation_2-1]);
- scratchAccessor.template set<scalar_t>((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val);
- }
+ for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
+ scratchAccessor.template set<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+invocationIndex,lv2_val[i]);
}
scratchAccessor.workgroupExecutionAndMemoryBarrier();

+ // combine with level 1
+ if (glsl::gl_SubgroupID() < lv1_smem_size)
+ {
+ vector_lv1_t lv1_val;
+ [unroll]
+ for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
+ scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
+
+ scalar_t lv2_scan;
+ const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
+ scratchAccessor.template get<scalar_t>(lv1_smem_size+bankedIndex, lv2_scan);
+
+ [unroll]
+ for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
+ scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex, binop(lv1_val[i],lv2_scan));
+ }
+
// combine with level 0
[unroll]
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
{
vector_lv0_t value;
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);

- const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx); // idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
- const scalar_t left;
- scratchAccessor.template get<scalar_t>(virtualSubgroupID, left);
+ const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
+ const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
+ scalar_t left;
+ scratchAccessor.template get<scalar_t>(bankedIndex,left);
if (Exclusive)
{
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
[unroll]
- for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
- value[Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(value[Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
+ for (uint32_t i = Config::ItemsPerInvocation_0-1; i > 0; i--)
+ value[i] = binop(left, value[i-1]);
+ value[0] = binop(left, left_last_elem);
}
else
{