@@ -55,6 +55,14 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
5555 return std::nullopt ;
5656}
5757
58+ static std::optional<StringRef> extractMdStringValue (MDNode *Node,
59+ unsigned int OpId) {
60+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand (OpId));
61+ if (NodeText == nullptr )
62+ return std::nullopt ;
63+ return NodeText->getString ();
64+ }
65+
5866static bool parseRootFlags (LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
5967 MDNode *RootFlagNode) {
6068
@@ -107,17 +115,79 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
107115 return false ;
108116}
109117
118+ static bool parseRootDescriptors (LLVMContext *Ctx,
119+ mcdxbc::RootSignatureDesc &RSD,
120+ MDNode *RootDescriptorNode,
121+ RootSignatureElementKind ElementKind) {
122+ assert (ElementKind == RootSignatureElementKind::SRV ||
123+ ElementKind == RootSignatureElementKind::UAV ||
124+ ElementKind == RootSignatureElementKind::CBV &&
125+ " parseRootDescriptors should only be called with RootDescriptor "
126+ " element kind." );
127+ if (RootDescriptorNode->getNumOperands () != 5 )
128+ return reportError (Ctx, " Invalid format for Root Descriptor Element" );
129+
130+ dxbc::RTS0::v1::RootParameterHeader Header;
131+ switch (ElementKind) {
132+ case RootSignatureElementKind::SRV:
133+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::SRV);
134+ break ;
135+ case RootSignatureElementKind::UAV:
136+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::UAV);
137+ break ;
138+ case RootSignatureElementKind::CBV:
139+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::CBV);
140+ break ;
141+ default :
142+ llvm_unreachable (" invalid Root Descriptor kind" );
143+ break ;
144+ }
145+
146+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 1 ))
147+ Header.ShaderVisibility = *Val;
148+ else
149+ return reportError (Ctx, " Invalid value for ShaderVisibility" );
150+
151+ dxbc::RTS0::v2::RootDescriptor Descriptor;
152+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 2 ))
153+ Descriptor.ShaderRegister = *Val;
154+ else
155+ return reportError (Ctx, " Invalid value for ShaderRegister" );
156+
157+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 3 ))
158+ Descriptor.RegisterSpace = *Val;
159+ else
160+ return reportError (Ctx, " Invalid value for RegisterSpace" );
161+
162+ if (RSD.Version == 1 ) {
163+ RSD.ParametersContainer .addParameter (Header, Descriptor);
164+ return false ;
165+ }
166+ assert (RSD.Version > 1 );
167+
168+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 4 ))
169+ Descriptor.Flags = *Val;
170+ else
171+ return reportError (Ctx, " Invalid value for Root Descriptor Flags" );
172+
173+ RSD.ParametersContainer .addParameter (Header, Descriptor);
174+ return false ;
175+ }
176+
110177static bool parseRootSignatureElement (LLVMContext *Ctx,
111178 mcdxbc::RootSignatureDesc &RSD,
112179 MDNode *Element) {
113- MDString * ElementText = cast<MDString> (Element-> getOperand ( 0 ) );
114- if (ElementText == nullptr )
180+ std::optional<StringRef> ElementText = extractMdStringValue (Element, 0 );
181+ if (! ElementText. has_value () )
115182 return reportError (Ctx, " Invalid format for Root Element" );
116183
117184 RootSignatureElementKind ElementKind =
118- StringSwitch<RootSignatureElementKind>(ElementText-> getString () )
185+ StringSwitch<RootSignatureElementKind>(* ElementText)
119186 .Case (" RootFlags" , RootSignatureElementKind::RootFlags)
120187 .Case (" RootConstants" , RootSignatureElementKind::RootConstants)
188+ .Case (" RootCBV" , RootSignatureElementKind::CBV)
189+ .Case (" RootSRV" , RootSignatureElementKind::SRV)
190+ .Case (" RootUAV" , RootSignatureElementKind::UAV)
121191 .Default (RootSignatureElementKind::Error);
122192
123193 switch (ElementKind) {
@@ -126,10 +196,12 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
126196 return parseRootFlags (Ctx, RSD, Element);
127197 case RootSignatureElementKind::RootConstants:
128198 return parseRootConstants (Ctx, RSD, Element);
129- break ;
199+ case RootSignatureElementKind::CBV:
200+ case RootSignatureElementKind::SRV:
201+ case RootSignatureElementKind::UAV:
202+ return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
130203 case RootSignatureElementKind::Error:
131- return reportError (Ctx, " Invalid Root Signature Element: " +
132- ElementText->getString ());
204+ return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
133205 }
134206
135207 llvm_unreachable (" Unhandled RootSignatureElementKind enum." );
@@ -157,6 +229,18 @@ static bool verifyVersion(uint32_t Version) {
157229 return (Version == 1 || Version == 2 );
158230}
159231
232+ static bool verifyRegisterValue (uint32_t RegisterValue) {
233+ return RegisterValue != ~0U ;
234+ }
235+
236+ // This Range is reserverved, therefore invalid, according to the spec
237+ // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
238+ static bool verifyRegisterSpace (uint32_t RegisterSpace) {
239+ return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF );
240+ }
241+
242+ static bool verifyDescriptorFlag (uint32_t Flags) { return (Flags & ~0xE ) == 0 ; }
243+
160244static bool validate (LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
161245
162246 if (!verifyVersion (RSD.Version )) {
@@ -174,6 +258,28 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
174258
175259 assert (dxbc::isValidParameterType (Info.Header .ParameterType ) &&
176260 " Invalid value for ParameterType" );
261+
262+ switch (Info.Header .ParameterType ) {
263+
264+ case llvm::to_underlying (dxbc::RootParameterType::CBV):
265+ case llvm::to_underlying (dxbc::RootParameterType::UAV):
266+ case llvm::to_underlying (dxbc::RootParameterType::SRV): {
267+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
268+ RSD.ParametersContainer .getRootDescriptor (Info.Location );
269+ if (!verifyRegisterValue (Descriptor.ShaderRegister ))
270+ return reportValueError (Ctx, " ShaderRegister" ,
271+ Descriptor.ShaderRegister );
272+
273+ if (!verifyRegisterSpace (Descriptor.RegisterSpace ))
274+ return reportValueError (Ctx, " RegisterSpace" , Descriptor.RegisterSpace );
275+
276+ if (RSD.Version > 1 ) {
277+ if (!verifyDescriptorFlag (Descriptor.Flags ))
278+ return reportValueError (Ctx, " DescriptorFlag" , Descriptor.Flags );
279+ }
280+ break ;
281+ }
282+ }
177283 }
178284
179285 return false ;
@@ -313,6 +419,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
313419 << " Shader Register: " << Constants.ShaderRegister << " \n " ;
314420 OS << indent (Space + 2 )
315421 << " Num 32 Bit Values: " << Constants.Num32BitValues << " \n " ;
422+ break ;
423+ }
424+ case llvm::to_underlying (dxbc::RootParameterType::CBV):
425+ case llvm::to_underlying (dxbc::RootParameterType::UAV):
426+ case llvm::to_underlying (dxbc::RootParameterType::SRV): {
427+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
428+ RS.ParametersContainer .getRootDescriptor (Loc);
429+ OS << indent (Space + 2 )
430+ << " Register Space: " << Descriptor.RegisterSpace << " \n " ;
431+ OS << indent (Space + 2 )
432+ << " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
433+ if (RS.Version > 1 )
434+ OS << indent (Space + 2 ) << " Flags: " << Descriptor.Flags << " \n " ;
435+ break ;
316436 }
317437 }
318438 Space--;
0 commit comments