15
15
#include " llvm/Frontend/HLSL/RootSignatureValidations.h"
16
16
#include " llvm/IR/IRBuilder.h"
17
17
#include " llvm/IR/Metadata.h"
18
+ #include " llvm/Support/Error.h"
18
19
#include " llvm/Support/ScopedPrinter.h"
20
+ #include < cstdint>
19
21
20
22
using namespace llvm ;
21
23
@@ -26,6 +28,8 @@ namespace rootsig {
26
28
char GenericRSMetadataError::ID;
27
29
char InvalidRSMetadataFormat::ID;
28
30
char InvalidRSMetadataValue::ID;
31
+ char TableSamplerMixinError::ID;
32
+ char TableRegisterOverflowError::ID;
29
33
template <typename T> char RootSignatureValidationError<T>::ID;
30
34
31
35
static std::optional<uint32_t > extractMdIntValue (MDNode *Node,
@@ -525,6 +529,83 @@ Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
525
529
llvm_unreachable (" Unhandled RootSignatureElementKind enum." );
526
530
}
527
531
532
+ Error validateDescriptorTableSamplerMixin (mcdxbc::DescriptorTable Table,
533
+ uint32_t Location) {
534
+ bool HasSampler = false ;
535
+ bool HasOtherRangeType = false ;
536
+ dxbc::DescriptorRangeType OtherRangeType;
537
+
538
+ for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges ) {
539
+ dxbc::DescriptorRangeType RangeType =
540
+ static_cast <dxbc::DescriptorRangeType>(Range.RangeType );
541
+
542
+ if (RangeType == dxbc::DescriptorRangeType::Sampler) {
543
+ HasSampler = true ;
544
+ } else {
545
+ HasOtherRangeType = true ;
546
+ OtherRangeType = RangeType;
547
+ }
548
+ }
549
+
550
+ // Samplers cannot be mixed with other resources in a descriptor table.
551
+ if (HasSampler && HasOtherRangeType)
552
+ return make_error<TableSamplerMixinError>(OtherRangeType, Location);
553
+ return Error::success ();
554
+ }
555
+
556
+ /* * This validation logic was extracted from the DXC codebase
557
+ * https://github.com/microsoft/DirectXShaderCompiler/blob/7a1b1df9b50a8350a63756720e85196e0285e664/lib/DxilRootSignature/DxilRootSignatureValidator.cpp#L205
558
+ *
559
+ * It checks if the registers in a descriptor table are overflowing, meaning,
560
+ * they are trying to bind a register larger than MAX_UINT.
561
+ * This will usually happen when the descriptor table defined a range after an
562
+ * unbounded range, which would lead to an overflow in the register;
563
+ * Or if trying append a bunch or really large ranges.
564
+ **/
565
+ Error validateDescriptorTableRegisterOverflow (mcdxbc::DescriptorTable Table,
566
+ uint32_t Location) {
567
+ uint64_t AppendingRegister = 0 ;
568
+
569
+ for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges ) {
570
+
571
+ dxbc::DescriptorRangeType RangeType =
572
+ static_cast <dxbc::DescriptorRangeType>(Range.RangeType );
573
+
574
+ uint64_t Register = AppendingRegister;
575
+
576
+ // Checks if the current register should be appended to the previous range.
577
+ if (Range.OffsetInDescriptorsFromTableStart != ~0U )
578
+ Register = Range.OffsetInDescriptorsFromTableStart ;
579
+
580
+ // Check for overflow in the register value.
581
+ if (Register > ~0U )
582
+ return make_error<TableRegisterOverflowError>(
583
+ RangeType, Range.BaseShaderRegister , Range.RegisterSpace );
584
+ // Is the current range unbounded?
585
+ if (Range.NumDescriptors == ~0U ) {
586
+ // No ranges should be appended to an unbounded range.
587
+ AppendingRegister = (uint64_t )~0U + (uint64_t )1ULL ;
588
+ } else {
589
+ // Is the defined range, overflowing?
590
+ uint64_t UpperBound = (uint64_t )Range.BaseShaderRegister +
591
+ (uint64_t )Range.NumDescriptors - (uint64_t )1U ;
592
+ if (UpperBound > ~0U )
593
+ return make_error<TableRegisterOverflowError>(
594
+ RangeType, Range.BaseShaderRegister , Range.RegisterSpace );
595
+
596
+ // If we append this range, will it overflow?
597
+ uint64_t AppendingUpperBound =
598
+ (uint64_t )Register + (uint64_t )Range.NumDescriptors - (uint64_t )1U ;
599
+ if (AppendingUpperBound > ~0U )
600
+ return make_error<TableRegisterOverflowError>(
601
+ RangeType, Range.BaseShaderRegister , Range.RegisterSpace );
602
+ AppendingRegister = Register + Range.NumDescriptors ;
603
+ }
604
+ }
605
+
606
+ return Error::success ();
607
+ }
608
+
528
609
Error MetadataParser::validateRootSignature (
529
610
const mcdxbc::RootSignatureDesc &RSD) {
530
611
Error DeferredErrs = Error::success ();
@@ -609,6 +690,16 @@ Error MetadataParser::validateRootSignature(
609
690
joinErrors (std::move (DeferredErrs),
610
691
make_error<RootSignatureValidationError<uint32_t >>(
611
692
" DescriptorFlag" , Range.Flags ));
693
+
694
+ if (Error Err =
695
+ validateDescriptorTableSamplerMixin (Table, Info.Location )) {
696
+ DeferredErrs = joinErrors (std::move (DeferredErrs), std::move (Err));
697
+ }
698
+
699
+ if (Error Err =
700
+ validateDescriptorTableRegisterOverflow (Table, Info.Location )) {
701
+ DeferredErrs = joinErrors (std::move (DeferredErrs), std::move (Err));
702
+ }
612
703
}
613
704
break ;
614
705
}
0 commit comments