@@ -174,6 +174,93 @@ static bool parseRootDescriptors(LLVMContext *Ctx,
174174 return false ;
175175}
176176
177+ static bool parseDescriptorRange (LLVMContext *Ctx,
178+ mcdxbc::DescriptorTable &Table,
179+ MDNode *RangeDescriptorNode) {
180+
181+ if (RangeDescriptorNode->getNumOperands () != 6 )
182+ return reportError (Ctx, " Invalid format for Descriptor Range" );
183+
184+ dxbc::RTS0::v2::DescriptorRange Range;
185+
186+ std::optional<StringRef> ElementText =
187+ extractMdStringValue (RangeDescriptorNode, 0 );
188+
189+ if (!ElementText.has_value ())
190+ return reportError (Ctx, " Descriptor Range, first element is not a string." );
191+
192+ Range.RangeType =
193+ StringSwitch<uint32_t >(*ElementText)
194+ .Case (" CBV" , llvm::to_underlying (dxbc::DescriptorRangeType::CBV))
195+ .Case (" SRV" , llvm::to_underlying (dxbc::DescriptorRangeType::SRV))
196+ .Case (" UAV" , llvm::to_underlying (dxbc::DescriptorRangeType::UAV))
197+ .Case (" Sampler" ,
198+ llvm::to_underlying (dxbc::DescriptorRangeType::Sampler))
199+ .Default (~0U );
200+
201+ if (Range.RangeType == ~0U )
202+ return reportError (Ctx, " Invalid Descriptor Range type: " + *ElementText);
203+
204+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 1 ))
205+ Range.NumDescriptors = *Val;
206+ else
207+ return reportError (Ctx, " Invalid value for Number of Descriptor in Range" );
208+
209+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 2 ))
210+ Range.BaseShaderRegister = *Val;
211+ else
212+ return reportError (Ctx, " Invalid value for BaseShaderRegister" );
213+
214+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 3 ))
215+ Range.RegisterSpace = *Val;
216+ else
217+ return reportError (Ctx, " Invalid value for RegisterSpace" );
218+
219+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 4 ))
220+ Range.OffsetInDescriptorsFromTableStart = *Val;
221+ else
222+ return reportError (Ctx,
223+ " Invalid value for OffsetInDescriptorsFromTableStart" );
224+
225+ if (std::optional<uint32_t > Val = extractMdIntValue (RangeDescriptorNode, 5 ))
226+ Range.Flags = *Val;
227+ else
228+ return reportError (Ctx, " Invalid value for Descriptor Range Flags" );
229+
230+ Table.Ranges .push_back (Range);
231+ return false ;
232+ }
233+
234+ static bool parseDescriptorTable (LLVMContext *Ctx,
235+ mcdxbc::RootSignatureDesc &RSD,
236+ MDNode *DescriptorTableNode) {
237+ const unsigned int NumOperands = DescriptorTableNode->getNumOperands ();
238+ if (NumOperands < 2 )
239+ return reportError (Ctx, " Invalid format for Descriptor Table" );
240+
241+ dxbc::RTS0::v1::RootParameterHeader Header;
242+ if (std::optional<uint32_t > Val = extractMdIntValue (DescriptorTableNode, 1 ))
243+ Header.ShaderVisibility = *Val;
244+ else
245+ return reportError (Ctx, " Invalid value for ShaderVisibility" );
246+
247+ mcdxbc::DescriptorTable Table;
248+ Header.ParameterType =
249+ llvm::to_underlying (dxbc::RootParameterType::DescriptorTable);
250+
251+ for (unsigned int I = 2 ; I < NumOperands; I++) {
252+ MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand (I));
253+ if (Element == nullptr )
254+ return reportError (Ctx, " Missing Root Element Metadata Node." );
255+
256+ if (parseDescriptorRange (Ctx, Table, Element))
257+ return true ;
258+ }
259+
260+ RSD.ParametersContainer .addParameter (Header, Table);
261+ return false ;
262+ }
263+
177264static bool parseRootSignatureElement (LLVMContext *Ctx,
178265 mcdxbc::RootSignatureDesc &RSD,
179266 MDNode *Element) {
@@ -188,6 +275,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
188275 .Case (" RootCBV" , RootSignatureElementKind::CBV)
189276 .Case (" RootSRV" , RootSignatureElementKind::SRV)
190277 .Case (" RootUAV" , RootSignatureElementKind::UAV)
278+ .Case (" DescriptorTable" , RootSignatureElementKind::DescriptorTable)
191279 .Default (RootSignatureElementKind::Error);
192280
193281 switch (ElementKind) {
@@ -200,6 +288,8 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
200288 case RootSignatureElementKind::SRV:
201289 case RootSignatureElementKind::UAV:
202290 return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
291+ case RootSignatureElementKind::DescriptorTable:
292+ return parseDescriptorTable (Ctx, RSD, Element);
203293 case RootSignatureElementKind::Error:
204294 return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
205295 }
@@ -241,6 +331,81 @@ static bool verifyRegisterSpace(uint32_t RegisterSpace) {
241331
242332static bool verifyDescriptorFlag (uint32_t Flags) { return (Flags & ~0xE ) == 0 ; }
243333
334+ static bool verifyRangeType (uint32_t Type) {
335+ switch (Type) {
336+ case llvm::to_underlying (dxbc::DescriptorRangeType::CBV):
337+ case llvm::to_underlying (dxbc::DescriptorRangeType::SRV):
338+ case llvm::to_underlying (dxbc::DescriptorRangeType::UAV):
339+ case llvm::to_underlying (dxbc::DescriptorRangeType::Sampler):
340+ return true ;
341+ };
342+
343+ return false ;
344+ }
345+
346+ static bool verifyDescriptorRangeFlag (uint32_t Version, uint32_t Type,
347+ uint32_t FlagsVal) {
348+ using FlagT = dxbc::DescriptorRangeFlag;
349+ FlagT Flags = FlagT (FlagsVal);
350+
351+ const bool IsSampler =
352+ (Type == llvm::to_underlying (dxbc::DescriptorRangeType::Sampler));
353+
354+ if (Version == 1 ) {
355+ // Since the metadata is unversioned, we expect to explicitly see the values
356+ // that map to the version 1 behaviour here.
357+ if (IsSampler)
358+ return Flags == FlagT::DESCRIPTORS_VOLATILE;
359+ return Flags == (FlagT::DATA_VOLATILE | FlagT::DESCRIPTORS_VOLATILE);
360+ }
361+
362+ // The data-specific flags are mutually exclusive.
363+ FlagT DataFlags = FlagT::DATA_VOLATILE | FlagT::DATA_STATIC |
364+ FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
365+
366+ if (popcount (llvm::to_underlying (Flags & DataFlags)) > 1 )
367+ return false ;
368+
369+ // The descriptor-specific flags are mutually exclusive.
370+ FlagT DescriptorFlags =
371+ FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS |
372+ FlagT::DESCRIPTORS_VOLATILE;
373+ if (popcount (llvm::to_underlying (Flags & DescriptorFlags)) > 1 )
374+ return false ;
375+
376+ // For volatile descriptors, DATA_STATIC is never valid.
377+ if ((Flags & FlagT::DESCRIPTORS_VOLATILE) == FlagT::DESCRIPTORS_VOLATILE) {
378+ FlagT Mask = FlagT::DESCRIPTORS_VOLATILE;
379+ if (!IsSampler) {
380+ Mask |= FlagT::DATA_VOLATILE;
381+ Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
382+ }
383+ return (Flags & ~Mask) == FlagT::NONE;
384+ }
385+
386+ // For "STATIC_KEEPING_BUFFER_BOUNDS_CHECKS" descriptors,
387+ // the other data-specific flags may all be set.
388+ if ((Flags & FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) ==
389+ FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS) {
390+ FlagT Mask = FlagT::DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS;
391+ if (!IsSampler) {
392+ Mask |= FlagT::DATA_VOLATILE;
393+ Mask |= FlagT::DATA_STATIC;
394+ Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
395+ }
396+ return (Flags & ~Mask) == FlagT::NONE;
397+ }
398+
399+ // When no descriptor flag is set, any data flag is allowed.
400+ FlagT Mask = FlagT::NONE;
401+ if (!IsSampler) {
402+ Mask |= FlagT::DATA_VOLATILE;
403+ Mask |= FlagT::DATA_STATIC;
404+ Mask |= FlagT::DATA_STATIC_WHILE_SET_AT_EXECUTE;
405+ }
406+ return (Flags & ~Mask) == FlagT::NONE;
407+ }
408+
244409static bool validate (LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
245410
246411 if (!verifyVersion (RSD.Version )) {
@@ -275,7 +440,23 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
275440
276441 if (RSD.Version > 1 ) {
277442 if (!verifyDescriptorFlag (Descriptor.Flags ))
278- return reportValueError (Ctx, " DescriptorFlag" , Descriptor.Flags );
443+ return reportValueError (Ctx, " DescriptorRangeFlag" , Descriptor.Flags );
444+ }
445+ break ;
446+ }
447+ case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
448+ const mcdxbc::DescriptorTable &Table =
449+ RSD.ParametersContainer .getDescriptorTable (Info.Location );
450+ for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
451+ if (!verifyRangeType (Range.RangeType ))
452+ return reportValueError (Ctx, " RangeType" , Range.RangeType );
453+
454+ if (!verifyRegisterSpace (Range.RegisterSpace ))
455+ return reportValueError (Ctx, " RegisterSpace" , Range.RegisterSpace );
456+
457+ if (!verifyDescriptorRangeFlag (RSD.Version , Range.RangeType ,
458+ Range.Flags ))
459+ return reportValueError (Ctx, " DescriptorFlag" , Range.Flags );
279460 }
280461 break ;
281462 }
@@ -388,67 +569,67 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
388569
389570 OS << " Root Signature Definitions"
390571 << " \n " ;
391- uint8_t Space = 0 ;
392572 for (const Function &F : M) {
393573 auto It = RSDMap.find (&F);
394574 if (It == RSDMap.end ())
395575 continue ;
396576 const auto &RS = It->second ;
397577 OS << " Definition for '" << F.getName () << " ':\n " ;
398-
399578 // start root signature header
400- Space++;
401- OS << indent (Space) << " Flags: " << format_hex (RS.Flags , 8 ) << " \n " ;
402- OS << indent (Space) << " Version: " << RS.Version << " \n " ;
403- OS << indent (Space) << " RootParametersOffset: " << RS.RootParameterOffset
404- << " \n " ;
405- OS << indent (Space) << " NumParameters: " << RS.ParametersContainer .size ()
406- << " \n " ;
407- Space++;
579+ OS << " Flags: " << format_hex (RS.Flags , 8 ) << " \n "
580+ << " Version: " << RS.Version << " \n "
581+ << " RootParametersOffset: " << RS.RootParameterOffset << " \n "
582+ << " NumParameters: " << RS.ParametersContainer .size () << " \n " ;
408583 for (size_t I = 0 ; I < RS.ParametersContainer .size (); I++) {
409584 const auto &[Type, Loc] =
410585 RS.ParametersContainer .getTypeAndLocForParameter (I);
411586 const dxbc::RTS0::v1::RootParameterHeader Header =
412587 RS.ParametersContainer .getHeader (I);
413588
414- OS << indent (Space) << " - Parameter Type: " << Type << " \n " ;
415- OS << indent (Space + 2 )
416- << " Shader Visibility: " << Header.ShaderVisibility << " \n " ;
589+ OS << " - Parameter Type: " << Type << " \n "
590+ << " Shader Visibility: " << Header.ShaderVisibility << " \n " ;
417591
418592 switch (Type) {
419593 case llvm::to_underlying (dxbc::RootParameterType::Constants32Bit): {
420594 const dxbc::RTS0::v1::RootConstants &Constants =
421595 RS.ParametersContainer .getConstant (Loc);
422- OS << indent (Space + 2 ) << " Register Space: " << Constants.RegisterSpace
423- << " \n " ;
424- OS << indent (Space + 2 )
425- << " Shader Register: " << Constants.ShaderRegister << " \n " ;
426- OS << indent (Space + 2 )
427- << " Num 32 Bit Values: " << Constants.Num32BitValues << " \n " ;
596+ OS << " Register Space: " << Constants.RegisterSpace << " \n "
597+ << " Shader Register: " << Constants.ShaderRegister << " \n "
598+ << " Num 32 Bit Values: " << Constants.Num32BitValues << " \n " ;
428599 break ;
429600 }
430601 case llvm::to_underlying (dxbc::RootParameterType::CBV):
431602 case llvm::to_underlying (dxbc::RootParameterType::UAV):
432603 case llvm::to_underlying (dxbc::RootParameterType::SRV): {
433604 const dxbc::RTS0::v2::RootDescriptor &Descriptor =
434605 RS.ParametersContainer .getRootDescriptor (Loc);
435- OS << indent (Space + 2 )
436- << " Register Space: " << Descriptor.RegisterSpace << " \n " ;
437- OS << indent (Space + 2 )
438- << " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
606+ OS << " Register Space: " << Descriptor.RegisterSpace << " \n "
607+ << " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
439608 if (RS.Version > 1 )
440- OS << indent (Space + 2 ) << " Flags: " << Descriptor.Flags << " \n " ;
609+ OS << " Flags: " << Descriptor.Flags << " \n " ;
610+ break ;
611+ }
612+ case llvm::to_underlying (dxbc::RootParameterType::DescriptorTable): {
613+ const mcdxbc::DescriptorTable &Table =
614+ RS.ParametersContainer .getDescriptorTable (Loc);
615+ OS << " NumRanges: " << Table.Ranges .size () << " \n " ;
616+
617+ for (const dxbc::RTS0::v2::DescriptorRange Range : Table) {
618+ OS << " - Range Type: " << Range.RangeType << " \n "
619+ << " Register Space: " << Range.RegisterSpace << " \n "
620+ << " Base Shader Register: " << Range.BaseShaderRegister << " \n "
621+ << " Num Descriptors: " << Range.NumDescriptors << " \n "
622+ << " Offset In Descriptors From Table Start: "
623+ << Range.OffsetInDescriptorsFromTableStart << " \n " ;
624+ if (RS.Version > 1 )
625+ OS << " Flags: " << Range.Flags << " \n " ;
626+ }
441627 break ;
442628 }
443629 }
444- Space--;
445630 }
446- OS << indent (Space) << " NumStaticSamplers: " << 0 << " \n " ;
447- OS << indent (Space) << " StaticSamplersOffset: " << RS.StaticSamplersOffset
448- << " \n " ;
449-
450- Space--;
451- // end root signature header
631+ OS << " NumStaticSamplers: " << 0 << " \n " ;
632+ OS << " StaticSamplersOffset: " << RS.StaticSamplersOffset << " \n " ;
452633 }
453634 return PreservedAnalyses::all ();
454635}
0 commit comments