Skip to content

Commit 0259cf7

Browse files
author
joaosaffran
committed
adding test
2 parents c3d24b6 + 92b766b commit 0259cf7

File tree

3 files changed

+96
-63
lines changed

3 files changed

+96
-63
lines changed

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313
#include "DXILRootSignature.h"
1414
#include "DirectX.h"
15+
#include "llvm/ADT/STLForwardCompat.h"
1516
#include "llvm/ADT/StringSwitch.h"
1617
#include "llvm/ADT/Twine.h"
1718
#include "llvm/Analysis/DXILMetadataAnalysis.h"
@@ -57,7 +58,7 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
5758

5859
static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5960
unsigned int OpId) {
60-
MDString *NodeText = cast<MDString>(Node->getOperand(OpId));
61+
MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
6162
if (NodeText == nullptr)
6263
return std::nullopt;
6364
return NodeText->getString();
@@ -117,23 +118,31 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
117118

118119
static bool parseRootDescriptors(LLVMContext *Ctx,
119120
mcdxbc::RootSignatureDesc &RSD,
120-
MDNode *RootDescriptorNode) {
121-
121+
MDNode *RootDescriptorNode,
122+
RootSignatureElementKind ElementKind) {
123+
assert(ElementKind == RootSignatureElementKind::SRV ||
124+
ElementKind == RootSignatureElementKind::UAV ||
125+
ElementKind == RootSignatureElementKind::CBV &&
126+
"parseRootDescriptors should only be called with RootDescriptor "
127+
"element kind.");
122128
if (RootDescriptorNode->getNumOperands() != 5)
123129
return reportError(Ctx, "Invalid format for Root Descriptor Element");
124130

125-
std::optional<StringRef> ElementText =
126-
extractMdStringValue(RootDescriptorNode, 0);
127-
128-
if (!ElementText.has_value())
129-
return reportError(Ctx, "Root Descriptor, first element is not a string.");
130-
131131
dxbc::RTS0::v1::RootParameterHeader Header;
132-
Header.ParameterType =
133-
StringSwitch<uint32_t>(*ElementText)
134-
.Case("RootCBV", llvm::to_underlying(dxbc::RootParameterType::CBV))
135-
.Case("RootSRV", llvm::to_underlying(dxbc::RootParameterType::SRV))
136-
.Case("RootUAV", llvm::to_underlying(dxbc::RootParameterType::UAV));
132+
switch (ElementKind) {
133+
case RootSignatureElementKind::SRV:
134+
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
135+
break;
136+
case RootSignatureElementKind::UAV:
137+
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
138+
break;
139+
case RootSignatureElementKind::CBV:
140+
Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
141+
break;
142+
default:
143+
llvm_unreachable("invalid Root Descriptor kind");
144+
break;
145+
}
137146

138147
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
139148
Header.ShaderVisibility = *Val;
@@ -253,17 +262,17 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
253262
static bool parseRootSignatureElement(LLVMContext *Ctx,
254263
mcdxbc::RootSignatureDesc &RSD,
255264
MDNode *Element) {
256-
MDString *ElementText = cast<MDString>(Element->getOperand(0));
257-
if (ElementText == nullptr)
265+
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
266+
if (!ElementText.has_value())
258267
return reportError(Ctx, "Invalid format for Root Element");
259268

260269
RootSignatureElementKind ElementKind =
261-
StringSwitch<RootSignatureElementKind>(ElementText->getString())
270+
StringSwitch<RootSignatureElementKind>(*ElementText)
262271
.Case("RootFlags", RootSignatureElementKind::RootFlags)
263272
.Case("RootConstants", RootSignatureElementKind::RootConstants)
264-
.Case("RootCBV", RootSignatureElementKind::RootDescriptors)
265-
.Case("RootSRV", RootSignatureElementKind::RootDescriptors)
266-
.Case("RootUAV", RootSignatureElementKind::RootDescriptors)
273+
.Case("RootCBV", RootSignatureElementKind::CBV)
274+
.Case("RootSRV", RootSignatureElementKind::SRV)
275+
.Case("RootUAV", RootSignatureElementKind::UAV)
267276
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
268277
.Default(RootSignatureElementKind::Error);
269278

@@ -273,13 +282,14 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
273282
return parseRootFlags(Ctx, RSD, Element);
274283
case RootSignatureElementKind::RootConstants:
275284
return parseRootConstants(Ctx, RSD, Element);
276-
case RootSignatureElementKind::RootDescriptors:
277-
return parseRootDescriptors(Ctx, RSD, Element);
285+
case RootSignatureElementKind::CBV:
286+
case RootSignatureElementKind::SRV:
287+
case RootSignatureElementKind::UAV:
288+
return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
278289
case RootSignatureElementKind::DescriptorTable:
279290
return parseDescriptorTable(Ctx, RSD, Element);
280291
case RootSignatureElementKind::Error:
281-
return reportError(Ctx, "Invalid Root Signature Element: " +
282-
ElementText->getString());
292+
return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
283293
}
284294

285295
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
@@ -308,9 +318,11 @@ static bool verifyVersion(uint32_t Version) {
308318
}
309319

310320
static bool verifyRegisterValue(uint32_t RegisterValue) {
311-
return !(RegisterValue == 0xFFFFFFFF);
321+
return RegisterValue != ~0U;
312322
}
313323

324+
// This Range is reserverved, therefore invalid, according to the spec
325+
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
314326
static bool verifyRegisterSpace(uint32_t RegisterSpace) {
315327
return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace < 0xFFFFFFFF);
316328
}
@@ -408,42 +420,42 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
408420
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
409421
"Invalid value for ParameterType");
410422

411-
switch(Info.Header.ParameterType) {
412-
413-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
414-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
415-
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
416-
const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer.getRootDescriptor(Info.Location);
417-
if (!verifyRegisterValue(Descriptor.ShaderRegister))
418-
return reportValueError(Ctx, "ShaderRegister",
419-
Descriptor.ShaderRegister);
420-
421-
if (!verifyRegisterSpace(Descriptor.RegisterSpace))
422-
return reportValueError(Ctx, "RegisterSpace",
423-
Descriptor.RegisterSpace);
424-
425-
if(RSD.Version > 1) {
426-
if (!verifyDescriptorFlag(Descriptor.Flags))
427-
return reportValueError(Ctx, "DescriptorFlag", Descriptor.Flags);
428-
}
429-
break;
430-
}
431-
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
432-
const mcdxbc::DescriptorTable &Table =
433-
RSD.ParametersContainer.getDescriptorTable(Info.Location);
434-
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
435-
if (!verifyRangeType(Range.RangeType))
436-
return reportValueError(Ctx, "RangeType", Range.RangeType);
423+
switch (Info.Header.ParameterType) {
437424

438-
if (!verifyRegisterSpace(Range.RegisterSpace))
439-
return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
425+
case llvm::to_underlying(dxbc::RootParameterType::CBV):
426+
case llvm::to_underlying(dxbc::RootParameterType::UAV):
427+
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
428+
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
429+
RSD.ParametersContainer.getRootDescriptor(Info.Location);
430+
if (!verifyRegisterValue(Descriptor.ShaderRegister))
431+
return reportValueError(Ctx, "ShaderRegister",
432+
Descriptor.ShaderRegister);
440433

441-
if (!verifyDescriptorRangeFlag(RSD.Version, Range.RangeType,
442-
Range.Flags))
443-
return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
444-
}
445-
break;
434+
if (!verifyRegisterSpace(Descriptor.RegisterSpace))
435+
return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
436+
437+
if (RSD.Version > 1) {
438+
if (!verifyDescriptorFlag(Descriptor.Flags))
439+
return reportValueError(Ctx, "DescriptorFlag", Descriptor.Flags);
446440
}
441+
break;
442+
}
443+
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
444+
const mcdxbc::DescriptorTable &Table =
445+
RSD.ParametersContainer.getDescriptorTable(Info.Location);
446+
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
447+
if (!verifyRangeType(Range.RangeType))
448+
return reportValueError(Ctx, "RangeType", Range.RangeType);
449+
450+
if (!verifyRegisterSpace(Range.RegisterSpace))
451+
return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
452+
453+
if (!verifyDescriptorRangeFlag(RSD.Version, Range.RangeType,
454+
Range.Flags))
455+
return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
456+
}
457+
break;
458+
}
447459
}
448460
}
449461

@@ -587,14 +599,14 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
587599
case llvm::to_underlying(dxbc::RootParameterType::CBV):
588600
case llvm::to_underlying(dxbc::RootParameterType::UAV):
589601
case llvm::to_underlying(dxbc::RootParameterType::SRV): {
590-
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
602+
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
591603
RS.ParametersContainer.getRootDescriptor(Loc);
592604
OS << indent(Space + 2)
593605
<< "Register Space: " << Descriptor.RegisterSpace << "\n";
594606
OS << indent(Space + 2)
595607
<< "Shader Register: " << Descriptor.ShaderRegister << "\n";
596-
if(RS.Version > 1)
597-
OS << indent(Space + 2) << "Flags: " << Descriptor.Flags << "\n";
608+
if (RS.Version > 1)
609+
OS << indent(Space + 2) << "Flags: " << Descriptor.Flags << "\n";
598610
break;
599611
}
600612
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ enum class RootSignatureElementKind {
2828
Error = 0,
2929
RootFlags = 1,
3030
RootConstants = 2,
31-
RootDescriptors = 3,
32-
DescriptorTable = 4,
31+
SRV = 3,
32+
UAV = 4,
33+
CBV = 5,
34+
DescriptorTable = 6,
3335
};
3436
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
3537
friend AnalysisInfoMixin<RootSignatureAnalysis>;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
2+
3+
; CHECK: error: Invalid format for Root Element
4+
; CHECK-NOT: Root Signature Definitions
5+
6+
target triple = "dxil-unknown-shadermodel6.0-compute"
7+
8+
9+
define void @main() #0 {
10+
entry:
11+
ret void
12+
}
13+
14+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
15+
16+
!dx.rootsignatures = !{!0}
17+
!0 = !{ ptr @main, !1 }
18+
!1 = !{ !2 }
19+
!2 = !{ i32 0 }

0 commit comments

Comments
 (0)