1212// ===----------------------------------------------------------------------===//
1313#include " DXILRootSignature.h"
1414#include " DirectX.h"
15+ #include " llvm/ADT/StringRef.h"
1516#include " llvm/ADT/StringSwitch.h"
1617#include " llvm/ADT/Twine.h"
1718#include " llvm/Analysis/DXILMetadataAnalysis.h"
3031#include < cmath>
3132#include < cstdint>
3233#include < optional>
34+ #include < string>
3335#include < utility>
3436
3537using namespace llvm ;
@@ -48,6 +50,71 @@ static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
4850 return true ;
4951}
5052
53+ // Template function to get formatted type string based on C++ type
54+ template <typename T> std::string getTypeFormatted () {
55+ if constexpr (std::is_same_v<T, MDString>) {
56+ return " string" ;
57+ } else if constexpr (std::is_same_v<T, MDNode *> ||
58+ std::is_same_v<T, const MDNode *>) {
59+ return " metadata" ;
60+ } else if constexpr (std::is_same_v<T, ConstantAsMetadata *> ||
61+ std::is_same_v<T, const ConstantAsMetadata *>) {
62+ return " constant" ;
63+ } else if constexpr (std::is_same_v<T, ConstantAsMetadata>) {
64+ return " constant" ;
65+ } else if constexpr (std::is_same_v<T, ConstantInt *> ||
66+ std::is_same_v<T, const ConstantInt *>) {
67+ return " constant int" ;
68+ } else if constexpr (std::is_same_v<T, ConstantInt>) {
69+ return " constant int" ;
70+ }
71+ return " unknown" ;
72+ }
73+
74+ // Helper function to get the actual type of a metadata operand
75+ std::string getActualMDType (const MDNode *Node, unsigned Index) {
76+ if (!Node || Index >= Node->getNumOperands ())
77+ return " null" ;
78+
79+ Metadata *Op = Node->getOperand (Index);
80+ if (!Op)
81+ return " null" ;
82+
83+ if (isa<MDString>(Op))
84+ return getTypeFormatted<MDString>();
85+
86+ if (isa<ConstantAsMetadata>(Op)) {
87+ if (auto *CAM = dyn_cast<ConstantAsMetadata>(Op)) {
88+ Type *T = CAM->getValue ()->getType ();
89+ if (T->isIntegerTy ())
90+ return (Twine (" i" ) + Twine (T->getIntegerBitWidth ())).str ();
91+ if (T->isFloatingPointTy ())
92+ return T->isFloatTy () ? getTypeFormatted<float >()
93+ : T->isDoubleTy () ? getTypeFormatted<double >()
94+ : " fp" ;
95+
96+ return getTypeFormatted<ConstantAsMetadata>();
97+ }
98+ }
99+ if (isa<MDNode>(Op))
100+ return getTypeFormatted<MDNode *>();
101+
102+ return " unknown" ;
103+ }
104+
105+ // Helper function to simplify error reporting for invalid metadata values
106+ template <typename ET>
107+ auto reportInvalidTypeError (LLVMContext *Ctx, Twine ParamName,
108+ const MDNode *Node, unsigned Index) {
109+ std::string ExpectedType = getTypeFormatted<ET>();
110+ std::string ActualType = getActualMDType (Node, Index);
111+
112+ return reportError (Ctx, " Root Signature Node: " + ParamName +
113+ " expected metadata node of type " +
114+ ExpectedType + " at index " + Twine (Index) +
115+ " but got " + ActualType);
116+ }
117+
51118static std::optional<uint32_t > extractMdIntValue (MDNode *Node,
52119 unsigned int OpId) {
53120 if (auto *CI =
@@ -80,7 +147,8 @@ static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
80147 if (std::optional<uint32_t > Val = extractMdIntValue (RootFlagNode, 1 ))
81148 RSD.Flags = *Val;
82149 else
83- return reportError (Ctx, " Invalid value for RootFlag" );
150+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootFlagNode" ,
151+ RootFlagNode, 1 );
84152
85153 return false ;
86154}
@@ -100,23 +168,27 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
100168 if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 1 ))
101169 Header.ShaderVisibility = *Val;
102170 else
103- return reportError (Ctx, " Invalid value for ShaderVisibility" );
171+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
172+ RootConstantNode, 1 );
104173
105174 dxbc::RTS0::v1::RootConstants Constants;
106175 if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 2 ))
107176 Constants.ShaderRegister = *Val;
108177 else
109- return reportError (Ctx, " Invalid value for ShaderRegister" );
178+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
179+ RootConstantNode, 2 );
110180
111181 if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 3 ))
112182 Constants.RegisterSpace = *Val;
113183 else
114- return reportError (Ctx, " Invalid value for RegisterSpace" );
184+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
185+ RootConstantNode, 3 );
115186
116187 if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 4 ))
117188 Constants.Num32BitValues = *Val;
118189 else
119- return reportError (Ctx, " Invalid value for Num32BitValues" );
190+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootConstantNode" ,
191+ RootConstantNode, 4 );
120192
121193 RSD.ParametersContainer .addParameter (Header, Constants);
122194
@@ -154,18 +226,21 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
154226 if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 1 ))
155227 Header.ShaderVisibility = *Val;
156228 else
157- return reportError (Ctx, " Invalid value for ShaderVisibility" );
229+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
230+ RootDescriptorNode, 1 );
158231
159232 dxbc::RTS0::v2::RootDescriptor Descriptor;
160233 if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 2 ))
161234 Descriptor.ShaderRegister = *Val;
162235 else
163- return reportError (Ctx, " Invalid value for ShaderRegister" );
236+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
237+ RootDescriptorNode, 2 );
164238
165239 if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 3 ))
166240 Descriptor.RegisterSpace = *Val;
167241 else
168- return reportError (Ctx, " Invalid value for RegisterSpace" );
242+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
243+ RootDescriptorNode, 3 );
169244
170245 if (RSD.Version == 1 ) {
171246 RSD.ParametersContainer .addParameter (Header, Descriptor);
@@ -176,7 +251,8 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
176251 if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 4 ))
177252 Descriptor.Flags = *Val;
178253 else
179- return reportError (Ctx, " Invalid value for Root Descriptor Flags" );
254+ return reportInvalidTypeError<ConstantInt>(Ctx, " RootDescriptorNode" ,
255+ RootDescriptorNode, 4 );
180256
181257 RSD.ParametersContainer .addParameter (Header, Descriptor);
182258 return false ;
@@ -196,7 +272,8 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
196272 extractMdStringValue (RangeDescriptorNode, 0 );
197273
198274 if (!ElementText.has_value ())
199- return reportError (Ctx, " Descriptor Range, first element is not a string." );
275+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
276+ RangeDescriptorNode, 0 );
200277
201278 Range.RangeType =
202279 StringSwitch<uint32_t >(*ElementText)
@@ -213,28 +290,32 @@ static bool parseDescriptorRange(LLVMContext *Ctx,
213290 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 1 ))
214291 Range.NumDescriptors = *Val;
215292 else
216- return reportError (Ctx, " Invalid value for Number of Descriptor in Range" );
293+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
294+ RangeDescriptorNode, 1 );
217295
218296 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 2 ))
219297 Range.BaseShaderRegister = *Val;
220298 else
221- return reportError (Ctx, " Invalid value for BaseShaderRegister" );
299+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
300+ RangeDescriptorNode, 2 );
222301
223302 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 3 ))
224303 Range.RegisterSpace = *Val;
225304 else
226- return reportError (Ctx, " Invalid value for RegisterSpace" );
305+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
306+ RangeDescriptorNode, 3 );
227307
228308 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 4 ))
229309 Range.OffsetInDescriptorsFromTableStart = *Val;
230310 else
231- return reportError (Ctx,
232- " Invalid value for OffsetInDescriptorsFromTableStart " );
311+ return reportInvalidTypeError<MDString> (Ctx, " RangeDescriptorNode " ,
312+ RangeDescriptorNode, 4 );
233313
234314 if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 5 ))
235315 Range.Flags = *Val;
236316 else
237- return reportError (Ctx, " Invalid value for Descriptor Range Flags" );
317+ return reportInvalidTypeError<MDString>(Ctx, " RangeDescriptorNode" ,
318+ RangeDescriptorNode, 5 );
238319
239320 Table.Ranges .push_back (Range);
240321 return false ;
@@ -251,7 +332,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
251332 if (std::optional<uint32_t > Val = extractMdIntValue (DescriptorTableNode, 1 ))
252333 Header.ShaderVisibility = *Val;
253334 else
254- return reportError (Ctx, " Invalid value for ShaderVisibility" );
335+ return reportInvalidTypeError<MDString>(Ctx, " DescriptorTableNode" ,
336+ DescriptorTableNode, 1 );
255337
256338 mcdxbc::DescriptorTable Table;
257339 Header.ParameterType =
@@ -260,7 +342,8 @@ static bool parseDescriptorTable(LLVMContext *Ctx,
260342 for (unsigned int I = 2 ; I < NumOperands; I++) {
261343 MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand (I));
262344 if (Element == nullptr )
263- return reportError (Ctx, " Missing Root Element Metadata Node." );
345+ return reportInvalidTypeError<MDNode>(Ctx, " DescriptorTableNode" ,
346+ DescriptorTableNode, I);
264347
265348 if (parseDescriptorRange (Ctx, RSD, Table, Element))
266349 return true ;
0 commit comments