12
12
// ===----------------------------------------------------------------------===//
13
13
#include " DXILRootSignature.h"
14
14
#include " DirectX.h"
15
+ #include " llvm/ADT/STLForwardCompat.h"
15
16
#include " llvm/ADT/StringSwitch.h"
16
17
#include " llvm/ADT/Twine.h"
17
18
#include " llvm/Analysis/DXILMetadataAnalysis.h"
@@ -57,7 +58,7 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
57
58
58
59
static std::optional<StringRef> extractMdStringValue (MDNode *Node,
59
60
unsigned int OpId) {
60
- MDString *NodeText = cast <MDString>(Node->getOperand (OpId));
61
+ MDString *NodeText = dyn_cast <MDString>(Node->getOperand (OpId));
61
62
if (NodeText == nullptr )
62
63
return std::nullopt;
63
64
return NodeText->getString ();
@@ -117,23 +118,31 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
117
118
118
119
static bool parseRootDescriptors (LLVMContext *Ctx,
119
120
mcdxbc::RootSignatureDesc &RSD,
120
- MDNode *RootDescriptorNode) {
121
-
121
+ MDNode *RootDescriptorNode,
122
+ RootSignatureElementKind ElementKind) {
123
+ assert (ElementKind == RootSignatureElementKind::SRV ||
124
+ ElementKind == RootSignatureElementKind::UAV ||
125
+ ElementKind == RootSignatureElementKind::CBV &&
126
+ " parseRootDescriptors should only be called with RootDescriptor "
127
+ " element kind." );
122
128
if (RootDescriptorNode->getNumOperands () != 5 )
123
129
return reportError (Ctx, " Invalid format for Root Descriptor Element" );
124
130
125
- std::optional<StringRef> ElementText =
126
- extractMdStringValue (RootDescriptorNode, 0 );
127
-
128
- if (!ElementText.has_value ())
129
- return reportError (Ctx, " Root Descriptor, first element is not a string." );
130
-
131
131
dxbc::RTS0::v1::RootParameterHeader Header;
132
- Header.ParameterType =
133
- StringSwitch<uint32_t >(*ElementText)
134
- .Case (" RootCBV" , llvm::to_underlying (dxbc::RootParameterType::CBV))
135
- .Case (" RootSRV" , llvm::to_underlying (dxbc::RootParameterType::SRV))
136
- .Case (" RootUAV" , llvm::to_underlying (dxbc::RootParameterType::UAV));
132
+ switch (ElementKind) {
133
+ case RootSignatureElementKind::SRV:
134
+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::SRV);
135
+ break ;
136
+ case RootSignatureElementKind::UAV:
137
+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::UAV);
138
+ break ;
139
+ case RootSignatureElementKind::CBV:
140
+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::CBV);
141
+ break ;
142
+ default :
143
+ llvm_unreachable (" invalid Root Descriptor kind" );
144
+ break ;
145
+ }
137
146
138
147
if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 1 ))
139
148
Header.ShaderVisibility = *Val;
@@ -253,17 +262,17 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
253
262
static bool parseRootSignatureElement (LLVMContext *Ctx,
254
263
mcdxbc::RootSignatureDesc &RSD,
255
264
MDNode *Element) {
256
- MDString * ElementText = cast<MDString> (Element-> getOperand ( 0 ) );
257
- if (ElementText == nullptr )
265
+ std::optional<StringRef> ElementText = extractMdStringValue (Element, 0 );
266
+ if (! ElementText. has_value () )
258
267
return reportError (Ctx, " Invalid format for Root Element" );
259
268
260
269
RootSignatureElementKind ElementKind =
261
- StringSwitch<RootSignatureElementKind>(ElementText-> getString () )
270
+ StringSwitch<RootSignatureElementKind>(* ElementText)
262
271
.Case (" RootFlags" , RootSignatureElementKind::RootFlags)
263
272
.Case (" RootConstants" , RootSignatureElementKind::RootConstants)
264
- .Case (" RootCBV" , RootSignatureElementKind::RootDescriptors )
265
- .Case (" RootSRV" , RootSignatureElementKind::RootDescriptors )
266
- .Case (" RootUAV" , RootSignatureElementKind::RootDescriptors )
273
+ .Case (" RootCBV" , RootSignatureElementKind::CBV )
274
+ .Case (" RootSRV" , RootSignatureElementKind::SRV )
275
+ .Case (" RootUAV" , RootSignatureElementKind::UAV )
267
276
.Case (" DescriptorTable" , RootSignatureElementKind::DescriptorTable)
268
277
.Default (RootSignatureElementKind::Error);
269
278
@@ -273,13 +282,14 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
273
282
return parseRootFlags (Ctx, RSD, Element);
274
283
case RootSignatureElementKind::RootConstants:
275
284
return parseRootConstants (Ctx, RSD, Element);
276
- case RootSignatureElementKind::RootDescriptors:
277
- return parseRootDescriptors (Ctx, RSD, Element);
285
+ case RootSignatureElementKind::CBV:
286
+ case RootSignatureElementKind::SRV:
287
+ case RootSignatureElementKind::UAV:
288
+ return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
278
289
case RootSignatureElementKind::DescriptorTable:
279
290
return parseDescriptorTable (Ctx, RSD, Element);
280
291
case RootSignatureElementKind::Error:
281
- return reportError (Ctx, " Invalid Root Signature Element: " +
282
- ElementText->getString ());
292
+ return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
283
293
}
284
294
285
295
llvm_unreachable (" Unhandled RootSignatureElementKind enum." );
@@ -308,9 +318,11 @@ static bool verifyVersion(uint32_t Version) {
308
318
}
309
319
310
320
static bool verifyRegisterValue (uint32_t RegisterValue) {
311
- return !( RegisterValue == 0xFFFFFFFF ) ;
321
+ return RegisterValue != ~ 0U ;
312
322
}
313
323
324
+ // This Range is reserverved, therefore invalid, according to the spec
325
+ // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
314
326
static bool verifyRegisterSpace (uint32_t RegisterSpace) {
315
327
return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace < 0xFFFFFFFF );
316
328
}
@@ -408,42 +420,42 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
408
420
assert (dxbc::isValidParameterType (Info.Header .ParameterType ) &&
409
421
" Invalid value for ParameterType" );
410
422
411
- switch (Info.Header .ParameterType ) {
412
-
413
- case llvm::to_underlying (dxbc::RootParameterType::CBV):
414
- case llvm::to_underlying (dxbc::RootParameterType::UAV):
415
- case llvm::to_underlying (dxbc::RootParameterType::SRV): {
416
- const dxbc::RTS0::v2::RootDescriptor &Descriptor = RSD.ParametersContainer .getRootDescriptor (Info.Location );
417
- if (!verifyRegisterValue (Descriptor.ShaderRegister ))
418
- return reportValueError (Ctx, " ShaderRegister" ,
419
- Descriptor.ShaderRegister );
420
-
421
- if (!verifyRegisterSpace (Descriptor.RegisterSpace ))
422
- return reportValueError (Ctx, " RegisterSpace" ,
423
- Descriptor.RegisterSpace );
424
-
425
- if (RSD.Version > 1 ) {
426
- if (!verifyDescriptorFlag (Descriptor.Flags ))
427
- return reportValueError (Ctx, " DescriptorFlag" , Descriptor.Flags );
428
- }
429
- break ;
430
- }
431
- case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
432
- const mcdxbc::DescriptorTable &Table =
433
- RSD.ParametersContainer .getDescriptorTable (Info.Location );
434
- for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
435
- if (!verifyRangeType (Range.RangeType ))
436
- return reportValueError (Ctx, " RangeType" , Range.RangeType );
423
+ switch (Info.Header .ParameterType ) {
437
424
438
- if (!verifyRegisterSpace (Range.RegisterSpace ))
439
- return reportValueError (Ctx, " RegisterSpace" , Range.RegisterSpace );
425
+ case llvm::to_underlying (dxbc::RootParameterType::CBV):
426
+ case llvm::to_underlying (dxbc::RootParameterType::UAV):
427
+ case llvm::to_underlying (dxbc::RootParameterType::SRV): {
428
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
429
+ RSD.ParametersContainer .getRootDescriptor (Info.Location );
430
+ if (!verifyRegisterValue (Descriptor.ShaderRegister ))
431
+ return reportValueError (Ctx, " ShaderRegister" ,
432
+ Descriptor.ShaderRegister );
440
433
441
- if (!verifyDescriptorRangeFlag (RSD.Version , Range.RangeType ,
442
- Range.Flags ))
443
- return reportValueError (Ctx, " DescriptorFlag" , Range.Flags );
444
- }
445
- break ;
434
+ if (!verifyRegisterSpace (Descriptor.RegisterSpace ))
435
+ return reportValueError (Ctx, " RegisterSpace" , Descriptor.RegisterSpace );
436
+
437
+ if (RSD.Version > 1 ) {
438
+ if (!verifyDescriptorFlag (Descriptor.Flags ))
439
+ return reportValueError (Ctx, " DescriptorFlag" , Descriptor.Flags );
446
440
}
441
+ break ;
442
+ }
443
+ case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
444
+ const mcdxbc::DescriptorTable &Table =
445
+ RSD.ParametersContainer .getDescriptorTable (Info.Location );
446
+ for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
447
+ if (!verifyRangeType (Range.RangeType ))
448
+ return reportValueError (Ctx, " RangeType" , Range.RangeType );
449
+
450
+ if (!verifyRegisterSpace (Range.RegisterSpace ))
451
+ return reportValueError (Ctx, " RegisterSpace" , Range.RegisterSpace );
452
+
453
+ if (!verifyDescriptorRangeFlag (RSD.Version , Range.RangeType ,
454
+ Range.Flags ))
455
+ return reportValueError (Ctx, " DescriptorFlag" , Range.Flags );
456
+ }
457
+ break ;
458
+ }
447
459
}
448
460
}
449
461
@@ -587,14 +599,14 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
587
599
case llvm::to_underlying (dxbc::RootParameterType::CBV):
588
600
case llvm::to_underlying (dxbc::RootParameterType::UAV):
589
601
case llvm::to_underlying (dxbc::RootParameterType::SRV): {
590
- const dxbc::RTS0::v2::RootDescriptor &Descriptor =
602
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
591
603
RS.ParametersContainer .getRootDescriptor (Loc);
592
604
OS << indent (Space + 2 )
593
605
<< " Register Space: " << Descriptor.RegisterSpace << " \n " ;
594
606
OS << indent (Space + 2 )
595
607
<< " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
596
- if (RS.Version > 1 )
597
- OS << indent (Space + 2 ) << " Flags: " << Descriptor.Flags << " \n " ;
608
+ if (RS.Version > 1 )
609
+ OS << indent (Space + 2 ) << " Flags: " << Descriptor.Flags << " \n " ;
598
610
break ;
599
611
}
600
612
case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
0 commit comments