Skip to content

Commit 23186d1

Browse files
joaosaffranjoaosaffran
andcommitted
[HLSL] Adding support for Root Constants in LLVM Metadata (llvm#135085)
- Closes [llvm#126637](llvm#126637) --------- Co-authored-by: joaosaffran <[email protected]>
1 parent 78522ac commit 23186d1

File tree

3 files changed

+61
-84
lines changed

3 files changed

+61
-84
lines changed

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 45 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,19 @@ static bool reportError(LLVMContext *Ctx, Twine Message,
4040
return true;
4141
}
4242

43-
static bool reportValueError(LLVMContext *Ctx, Twine ParamName, uint32_t Value,
44-
DiagnosticSeverity Severity = DS_Error) {
43+
static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
44+
uint32_t Value) {
4545
Ctx->diagnose(DiagnosticInfoGeneric(
46-
"Invalid value for " + ParamName + ": " + Twine(Value), Severity));
46+
"Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
4747
return true;
4848
}
4949

50-
static bool extractMdIntValue(uint32_t &Value, MDNode *Node,
51-
unsigned int OpId) {
52-
auto *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get());
53-
if (CI == nullptr)
54-
return true;
55-
56-
Value = CI->getZExtValue();
57-
return false;
50+
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
51+
unsigned int OpId) {
52+
if (auto *CI =
53+
mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
54+
return CI->getZExtValue();
55+
return std::nullopt;
5856
}
5957

6058
static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
@@ -63,7 +61,9 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
6361
if (RootFlagNode->getNumOperands() != 2)
6462
return reportError(Ctx, "Invalid format for RootFlag Element");
6563

66-
if (extractMdIntValue(RSD.Flags, RootFlagNode, 1))
64+
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
65+
RSD.Flags = *Val;
66+
else
6767
return reportError(Ctx, "Invalid value for RootFlag");
6868

6969
return false;
@@ -79,22 +79,24 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
7979
NewParameter.Header.ParameterType =
8080
llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
8181

82-
uint32_t SV;
83-
if (extractMdIntValue(SV, RootConstantNode, 1))
82+
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
83+
NewParameter.Header.ShaderVisibility = *Val;
84+
else
8485
return reportError(Ctx, "Invalid value for ShaderVisibility");
8586

86-
NewParameter.Header.ShaderVisibility = SV;
87-
88-
if (extractMdIntValue(NewParameter.Constants.ShaderRegister, RootConstantNode,
89-
2))
87+
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
88+
NewParameter.Constants.ShaderRegister = *Val;
89+
else
9090
return reportError(Ctx, "Invalid value for ShaderRegister");
9191

92-
if (extractMdIntValue(NewParameter.Constants.RegisterSpace, RootConstantNode,
93-
3))
92+
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
93+
NewParameter.Constants.RegisterSpace = *Val;
94+
else
9495
return reportError(Ctx, "Invalid value for RegisterSpace");
9596

96-
if (extractMdIntValue(NewParameter.Constants.Num32BitValues, RootConstantNode,
97-
4))
97+
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
98+
NewParameter.Constants.Num32BitValues = *Val;
99+
else
98100
return reportError(Ctx, "Invalid value for Num32BitValues");
99101

100102
RSD.Parameters.push_back(NewParameter);
@@ -148,32 +150,6 @@ static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
148150

149151
static bool verifyRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; }
150152

151-
static bool verifyShaderVisibility(uint32_t Flags) {
152-
switch (Flags) {
153-
154-
case llvm::to_underlying(dxbc::ShaderVisibility::All):
155-
case llvm::to_underlying(dxbc::ShaderVisibility::Vertex):
156-
case llvm::to_underlying(dxbc::ShaderVisibility::Hull):
157-
case llvm::to_underlying(dxbc::ShaderVisibility::Domain):
158-
case llvm::to_underlying(dxbc::ShaderVisibility::Geometry):
159-
case llvm::to_underlying(dxbc::ShaderVisibility::Pixel):
160-
case llvm::to_underlying(dxbc::ShaderVisibility::Amplification):
161-
case llvm::to_underlying(dxbc::ShaderVisibility::Mesh):
162-
return true;
163-
}
164-
165-
return false;
166-
}
167-
168-
static bool verifyParameterType(uint32_t Type) {
169-
switch (Type) {
170-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
171-
return true;
172-
}
173-
174-
return false;
175-
}
176-
177153
static bool verifyVersion(uint32_t Version) {
178154
return (Version == 1 || Version == 2);
179155
}
@@ -188,12 +164,12 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
188164
return reportValueError(Ctx, "RootFlags", RSD.Flags);
189165
}
190166

191-
for (const auto &P : RSD.Parameters) {
192-
if (!verifyShaderVisibility(P.Header.ShaderVisibility))
167+
for (const mcdxbc::RootParameter &P : RSD.Parameters) {
168+
if (!dxbc::isValidShaderVisibility(P.Header.ShaderVisibility))
193169
return reportValueError(Ctx, "ShaderVisibility",
194-
(uint32_t)P.Header.ShaderVisibility);
170+
P.Header.ShaderVisibility);
195171

196-
assert(verifyParameterType(P.Header.ParameterType) &&
172+
assert(dxbc::isValidParameterType(P.Header.ParameterType) &&
197173
"Invalid value for ParameterType");
198174
}
199175

@@ -265,6 +241,10 @@ analyzeModule(Module &M) {
265241
}
266242

267243
mcdxbc::RootSignatureDesc RSD;
244+
// Clang emits the root signature data in dxcontainer following a specific
245+
// sequence. First the header, then the root parameters. So the header
246+
// offset will always equal to the header size.
247+
RSD.RootParameterOffset = sizeof(dxbc::RootSignatureHeader);
268248

269249
if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
270250
return RSDMap;
@@ -291,7 +271,6 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
291271
SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> &RSDMap =
292272
AM.getResult<RootSignatureAnalysis>(M);
293273

294-
const size_t RSHSize = sizeof(dxbc::RootSignatureHeader);
295274
OS << "Root Signature Definitions"
296275
<< "\n";
297276
uint8_t Space = 0;
@@ -306,32 +285,30 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
306285
Space++;
307286
OS << indent(Space) << "Flags: " << format_hex(RS.Flags, 8) << "\n";
308287
OS << indent(Space) << "Version: " << RS.Version << "\n";
309-
OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n";
310-
OS << indent(Space) << "RootParametersOffset: " << RSHSize << "\n";
311-
OS << indent(Space) << "NumStaticSamplers: " << 0 << "\n";
312-
OS << indent(Space)
313-
<< "StaticSamplersOffset: " << RSHSize + RS.Parameters.size_in_bytes()
288+
OS << indent(Space) << "RootParametersOffset: " << RS.RootParameterOffset
314289
<< "\n";
315-
290+
OS << indent(Space) << "NumParameters: " << RS.Parameters.size() << "\n";
316291
Space++;
317292
for (auto const &P : RS.Parameters) {
318-
OS << indent(Space)
319-
<< "Parameter Type: " << (uint32_t)P.Header.ParameterType << "\n";
320-
OS << indent(Space)
321-
<< "Shader Visibility: " << (uint32_t)P.Header.ShaderVisibility
293+
OS << indent(Space) << "- Parameter Type: " << P.Header.ParameterType
322294
<< "\n";
295+
OS << indent(Space + 2)
296+
<< "Shader Visibility: " << P.Header.ShaderVisibility << "\n";
323297
switch (P.Header.ParameterType) {
324298
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
325-
OS << indent(Space) << "Register Space: " << P.Constants.RegisterSpace
326-
<< "\n";
327-
OS << indent(Space) << "Shader Register: " << P.Constants.ShaderRegister
328-
<< "\n";
329-
OS << indent(Space)
299+
OS << indent(Space + 2)
300+
<< "Register Space: " << P.Constants.RegisterSpace << "\n";
301+
OS << indent(Space + 2)
302+
<< "Shader Register: " << P.Constants.ShaderRegister << "\n";
303+
OS << indent(Space + 2)
330304
<< "Num 32 Bit Values: " << P.Constants.Num32BitValues << "\n";
331305
break;
332306
}
333307
}
334308
Space--;
309+
OS << indent(Space) << "NumStaticSamplers: " << 0 << "\n";
310+
OS << indent(Space) << "StaticSamplersOffset: " << RS.StaticSamplersOffset
311+
<< "\n";
335312

336313
Space--;
337314
// end root signature header

llvm/test/CodeGen/DirectX/ContainerData/RootSignature-MultipleEntryFunctions.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
2626
; CHECK-LABEL: Definition for 'main':
2727
; CHECK-NEXT: Flags: 0x000001
2828
; CHECK-NEXT: Version: 2
29-
; CHECK-NEXT: NumParameters: 0
3029
; CHECK-NEXT: RootParametersOffset: 24
30+
; CHECK-NEXT: NumParameters: 0
3131
; CHECK-NEXT: NumStaticSamplers: 0
32-
; CHECK-NEXT: StaticSamplersOffset: 24
32+
; CHECK-NEXT: StaticSamplersOffset: 0
3333

3434
; CHECK-LABEL: Definition for 'anotherMain':
3535
; CHECK-NEXT: Flags: 0x000002
3636
; CHECK-NEXT: Version: 2
37-
; CHECK-NEXT: NumParameters: 0
3837
; CHECK-NEXT: RootParametersOffset: 24
38+
; CHECK-NEXT: NumParameters: 0
3939
; CHECK-NEXT: NumStaticSamplers: 0
40-
; CHECK-NEXT: StaticSamplersOffset: 24
40+
; CHECK-NEXT: StaticSamplersOffset: 0

llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Parameters.ll

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
1616
!4 = !{ !"RootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout
1717
!5 = !{ !"RootConstants", i32 0, i32 1, i32 2, i32 3 }
1818

19-
; CHECK-LABEL: Definition for 'main':
20-
; CHECK-NEXT: Flags: 0x000001
21-
; CHECK-NEXT: Version: 2
22-
; CHECK-NEXT: NumParameters: 1
23-
; CHECK-NEXT: RootParametersOffset: 24
24-
; CHECK-NEXT: NumStaticSamplers: 0
25-
; CHECK-NEXT: StaticSamplersOffset: 48
26-
; CHECK-NEXT: Parameter Type: 1
27-
; CHECK-NEXT: Shader Visibility: 0
28-
; CHECK-NEXT: Register Space: 2
29-
; CHECK-NEXT: Shader Register: 1
30-
; CHECK-NEXT: Num 32 Bit Values: 3
19+
;CHECK-LABEL: Definition for 'main':
20+
;CHECK-NEXT: Flags: 0x000001
21+
;CHECK-NEXT: Version: 2
22+
;CHECK-NEXT: RootParametersOffset: 24
23+
;CHECK-NEXT: NumParameters: 1
24+
;CHECK-NEXT: - Parameter Type: 1
25+
;CHECK-NEXT: Shader Visibility: 0
26+
;CHECK-NEXT: Register Space: 2
27+
;CHECK-NEXT: Shader Register: 1
28+
;CHECK-NEXT: Num 32 Bit Values: 3
29+
;CHECK-NEXT: NumStaticSamplers: 0
30+
;CHECK-NEXT: StaticSamplersOffset: 0

0 commit comments

Comments
 (0)