@@ -265,9 +265,10 @@ struct ScratchAccessorAdaptor {
265
265
struct scan_base
266
266
{
267
267
// even if you have a `const uint nbl::hlsl::subgroup::Size` it won't work, I think, so a `#define` is needed
268
- static const uint HalfSubgroupSize = WaveGetLaneCount ()>>1u; // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSize or nbl::hlsl::subgroup::Size
269
- static const uint LoMask = WaveGetLaneCount ()-1u; // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSize
270
- static const uint LastWorkgroupInvocation = _NBL_HLSL_WORKGROUP_SIZE_-1 ; // TODO (PentaKon): Where should this be defined?
268
+ static const uint SubgroupSize = nbl::hlsl::subgroup::subgroupSize ();
269
+ static const uint HalfSubgroupSize = SubgroupSize>>1u; // REVIEW: Is this ok?
270
+ static const uint LoMask = SubgroupSize-1u;
271
+ static const uint LastWorkgroupInvocation = _NBL_HLSL_WORKGROUP_SIZE_-1 ; // REVIEW: Where should this be defined?
271
272
static const uint pseudoSubgroupInvocation = localInvocationIndex&LoMask; // Also used in substructs, thus static const
272
273
273
274
static inclusive_scan<Binop,ScratchAccessor> create ()
@@ -283,7 +284,7 @@ struct scan_base
283
284
retval.scanStoreOffset = paddingMemoryEnd+pseudoSubgroupInvocation;
284
285
285
286
uint reductionResultOffset = paddingMemoryEnd;
286
- if ((LastWorkgroupInvocation>>firstbithigh ( WaveGetLaneCount ())) !=nbl::hlsl::subgroup::ID ()) // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSizeLog2
287
+ if ((LastWorkgroupInvocation>>nbl::hlsl::subgroup:: subgroupSizeLog2 ()) !=nbl::hlsl::subgroup::subgroupInvocationID ())
287
288
retval.reductionResultOffset += LastWorkgroupInvocation&LoMask;
288
289
else
289
290
retval.reductionResultOffset += LoMask;
@@ -319,7 +320,7 @@ struct inclusive_scan : scan_base
319
320
nbl::hlsl::subgroupMemoryBarrierShared ();
320
321
scratchAccessor.set (scanStoreOffset ,value);
321
322
if (scan_base::pseudoSubgroupInvocation<scan_base::HalfSubgroupSize)
322
- scratchAccessor.set (lastLoadOffset,Binop ::identity ());
323
+ scratchAccessor.set (lastLoadOffset,op ::identity ());
323
324
}
324
325
nbl::hlsl::subgroupBarrier ();
325
326
nbl::hlsl::subgroupMemoryBarrierShared ();
@@ -405,7 +406,7 @@ struct reduction
405
406
nbl::hlsl::subgroupBarrier ();
406
407
nbl::hlsl::subgroupMemoryBarrierShared ();
407
408
uint reductionResultOffset = impl.paddingMemoryEnd;
408
- if ((scan_base::LastWorkgroupInvocation>>firstbithigh ( WaveGetLaneCount ())) !=nbl::hlsl::subgroup::ID ()) // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSizeLog2
409
+ if ((scan_base::LastWorkgroupInvocation>>nbl::hlsl::subgroup:: subgroupSizeLog2 ()) !=nbl::hlsl::subgroup::ID ())
409
410
reductionResultOffset += scan_base::LastWorkgroupInvocation & scan_base::LoMask;
410
411
else
411
412
reductionResultOffset += scan_base::LoMask;
0 commit comments