Skip to content

Commit 3782a34

Browse files
committed
Add binop::identity() as well as attempt at SPIR-V implementations of subgroup barrier ops
1 parent 8863194 commit 3782a34

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed

include/nbl/builtin/hlsl/binops.hlsl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ struct bitwise_and
1717
{
1818
return lhs&rhs;
1919
}
20+
21+
T identity()
22+
{
23+
return ~0;
24+
}
2025
};
2126

2227
template<typename T>
@@ -26,6 +31,11 @@ struct bitwise_or
2631
{
2732
return lhs|rhs;
2833
}
34+
35+
T identity()
36+
{
37+
return 0;
38+
}
2939
};
3040

3141
template<typename T>
@@ -35,6 +45,11 @@ struct bitwise_xor
3545
{
3646
return lhs^rhs;
3747
}
48+
49+
T identity()
50+
{
51+
return 0;
52+
}
3853
};
3954

4055
template<typename T>
@@ -44,6 +59,11 @@ struct add
4459
{
4560
return lhs+rhs;
4661
}
62+
63+
T identity()
64+
{
65+
return 0;
66+
}
4767
};
4868

4969
template<typename T>
@@ -53,6 +73,11 @@ struct mul
5373
{
5474
return lhs*rhs;
5575
}
76+
77+
T identity()
78+
{
79+
return 1;
80+
}
5681
};
5782

5883
template<typename T, class Comparator>
@@ -72,6 +97,11 @@ struct min
7297
comparator_lt_t comp;
7398
return bitwise_min(lhs, rhs, comp);
7499
}
100+
101+
T identity()
102+
{
103+
return ~0;
104+
}
75105
};
76106

77107
template<typename T, class Comparator>
@@ -91,6 +121,11 @@ struct max
91121
comparator_gt_t comp;
92122
return bitwise_max(lhs, rhs, comp);
93123
}
124+
125+
T identity()
126+
{
127+
return 0; // REVIEW: This assumes T = unsigned but what if we got T = signed ?
128+
}
94129
};
95130

96131
template<typename T>

include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,10 @@ struct ScratchAccessorAdaptor {
265265
struct scan_base
266266
{
267267
// even if you have a `const uint nbl::hlsl::subgroup::Size` it wont work I think, so `#define` needed
268-
static const uint HalfSubgroupSize = WaveGetLaneCount()>>1u; // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSize or nbl::hlsl::subgroup::Size
269-
static const uint LoMask = WaveGetLaneCount()-1u; // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSize
270-
static const uint LastWorkgroupInvocation = _NBL_HLSL_WORKGROUP_SIZE_-1; // TODO (PentaKon): Where should this be defined?
268+
static const uint SubgroupSize = nbl::hlsl::subgroup::subgroupSize();
269+
static const uint HalfSubgroupSize = SubgroupSize>>1u; // REVIEW: Is this ok?
270+
static const uint LoMask = SubgroupSize-1u;
271+
static const uint LastWorkgroupInvocation = _NBL_HLSL_WORKGROUP_SIZE_-1; // REVIEW: Where should this be defined?
271272
static const uint pseudoSubgroupInvocation = localInvocationIndex&LoMask; // Also used in substructs, thus static const
272273

273274
static inclusive_scan<Binop,ScratchAccessor> create()
@@ -283,7 +284,7 @@ struct scan_base
283284
retval.scanStoreOffset = paddingMemoryEnd+pseudoSubgroupInvocation;
284285

285286
uint reductionResultOffset = paddingMemoryEnd;
286-
if ((LastWorkgroupInvocation>>firstbithigh(WaveGetLaneCount()))!=nbl::hlsl::subgroup::ID()) // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSizeLog2
287+
if ((LastWorkgroupInvocation>>nbl::hlsl::subgroup::subgroupSizeLog2())!=nbl::hlsl::subgroup::subgroupInvocationID())
287288
retval.reductionResultOffset += LastWorkgroupInvocation&LoMask;
288289
else
289290
retval.reductionResultOffset += LoMask;
@@ -319,7 +320,7 @@ struct inclusive_scan : scan_base
319320
nbl::hlsl::subgroupMemoryBarrierShared();
320321
scratchAccessor.set(scanStoreOffset ,value);
321322
if (scan_base::pseudoSubgroupInvocation<scan_base::HalfSubgroupSize)
322-
scratchAccessor.set(lastLoadOffset,Binop::identity());
323+
scratchAccessor.set(lastLoadOffset,op::identity());
323324
}
324325
nbl::hlsl::subgroupBarrier();
325326
nbl::hlsl::subgroupMemoryBarrierShared();
@@ -405,7 +406,7 @@ struct reduction
405406
nbl::hlsl::subgroupBarrier();
406407
nbl::hlsl::subgroupMemoryBarrierShared();
407408
uint reductionResultOffset = impl.paddingMemoryEnd;
408-
if ((scan_base::LastWorkgroupInvocation>>firstbithigh(WaveGetLaneCount()))!=nbl::hlsl::subgroup::ID()) // TODO (PentaKon): Replace with nbl_hlsl_SubgroupSizeLog2
409+
if ((scan_base::LastWorkgroupInvocation>>nbl::hlsl::subgroup::subgroupSizeLog2())!=nbl::hlsl::subgroup::ID())
409410
reductionResultOffset += scan_base::LastWorkgroupInvocation & scan_base::LoMask;
410411
else
411412
reductionResultOffset += scan_base::LoMask;

include/nbl/builtin/hlsl/subgroup/basic_portability.hlsl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ namespace hlsl
77
{
88
static const uint MaxWorkgroupSizeLog2 = 11;
99
static const uint MaxWorkgroupSize = 0x1 << MaxWorkgroupSizeLog2;
10+
11+
namespace subgroup
12+
{
1013
static const uint MinSubgroupSizeLog2 = 2;
1114
static const uint MinSubgroupSize = 0x1 << MinSubgroupSizeLog2;
1215

@@ -60,9 +63,27 @@ namespace hlsl
6063

6164
// WAVE BARRIERS
6265

66+
// REVIEW: Review everything related to subgroup barriers and SPIR-V injection
67+
68+
[[vk::ext_instruction(/* subgroupBarrier */ 224)]] // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpControlBarrier
69+
void spirv_subgroupBarrier(uint executionScope, uint memoryScope, uint memorySemantics);
70+
6371
void subgroupBarrier() {
72+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_scope_id
73+
// Subgroup scope is number 3
6474

75+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_memory_semantics_id
76+
// By providing memory semantics None we do both control and memory barrier as is done in GLSL
77+
spirv_subgroupBarrier(3, 3, 0x0);
6578
}
79+
80+
[[vk::ext_instruction(/* subgroupMemoryBarrierShared */ 225)]] // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpControlBarrier
81+
void spirv_subgroupMemoryBarrierShared(uint memoryScope, uint memorySemantics);
82+
83+
void subgroupMemoryBarrierShared() {
84+
spirv_subgroupMemoryBarrierShared(3, 0x0); // REVIEW: Need advice on memory semantics. Would think SubgroupMemory(0x80) but have no idea
85+
}
86+
}
6687
}
6788
}
6889

0 commit comments

Comments
 (0)