Skip to content

Conversation

joaosaffran
Copy link
Contributor

@joaosaffran joaosaffran commented Jul 2, 2025

DXC checks if registers are correctly bound to root signature descriptors. This implements the same check.
closes: #126645

Copy link

github-actions bot commented Jul 2, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@joaosaffran joaosaffran force-pushed the validation/check-descriptors-are-bound branch from d655039 to d8a91e2 Compare July 4, 2025 19:25
@joaosaffran joaosaffran changed the title [DirectX][Draft] validate registers are bound to root signature [DirectX] Validate registers are bound to root signature Jul 4, 2025
@joaosaffran joaosaffran changed the base branch from users/joaosaffran/146783 to main July 4, 2025 19:39
@joaosaffran joaosaffran changed the base branch from main to users/joaosaffran/146783 July 4, 2025 19:40
@joaosaffran joaosaffran force-pushed the validation/check-descriptors-are-bound branch from d8a91e2 to 28350b2 Compare July 5, 2025 00:35
@joaosaffran joaosaffran marked this pull request as ready for review July 5, 2025 00:40
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:DirectX HLSL HLSL Language Support labels Jul 5, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 5, 2025

@llvm/pr-subscribers-clang

Author: None (joaosaffran)

Changes

DXC checks if registers are correctly bound to root signature descriptors. This implements the same check.

Closes: 126645


Full diff: https://github.com/llvm/llvm-project/pull/146785.diff

7 Files Affected:

  • (added) clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl (+35)
  • (added) clang/test/SemaHLSL/RootSignature-Validation.hlsl (+33)
  • (modified) llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp (+133-3)
  • (modified) llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h (+119)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+1)
diff --git a/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
new file mode 100644
index 0000000000000..b590ed67e7085
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
@@ -0,0 +1,35 @@
+// RUN: not %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 | FileCheck %s
+
+// CHECK: error: register cbuffer (space=665, register=3) is not defined in Root Signature
+// CHECK: error: register srv (space=0, register=0) is not defined in Root Signature
+// CHECK: error: register uav (space=0, register=4294967295) is not defined in Root Signature
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=666, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space665) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967295);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
new file mode 100644
index 0000000000000..5a7f5baf00619
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
@@ -0,0 +1,33 @@
+// RUN: %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 
+
+// expected-no-diagnostics
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space1) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967294);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 398dcbb8d1737..a52a04323514c 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILPostOptimizationValidation.h"
+#include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallString.h"
@@ -84,8 +85,60 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
   }
 }
 
+static void reportRegNotBound(Module &M, Twine Type,
+                              ResourceInfo::ResourceBinding Binding) {
+  SmallString<128> Message;
+  raw_svector_ostream OS(Message);
+  OS << "register " << Type << " (space=" << Binding.Space
+     << ", register=" << Binding.LowerBound << ")"
+     << " is not defined in Root Signature";
+  M.getContext().diagnose(DiagnosticInfoGeneric(Message));
+}
+
+static dxbc::ShaderVisibility
+tripleToVisibility(llvm::Triple::EnvironmentType ET) {
+  assert((ET == Triple::Pixel || ET == Triple::Vertex ||
+          ET == Triple::Geometry || ET == Triple::Hull ||
+          ET == Triple::Domain || ET == Triple::Mesh ||
+          ET == Triple::Compute) &&
+         "Invalid Triple to shader stage conversion");
+
+  switch (ET) {
+  case Triple::Pixel:
+    return dxbc::ShaderVisibility::Pixel;
+  case Triple::Vertex:
+    return dxbc::ShaderVisibility::Vertex;
+  case Triple::Geometry:
+    return dxbc::ShaderVisibility::Geometry;
+  case Triple::Hull:
+    return dxbc::ShaderVisibility::Hull;
+  case Triple::Domain:
+    return dxbc::ShaderVisibility::Domain;
+  case Triple::Mesh:
+    return dxbc::ShaderVisibility::Mesh;
+  case Triple::Compute:
+    return dxbc::ShaderVisibility::All;
+  default:
+    llvm_unreachable("Invalid triple to shader stage conversion");
+  }
+}
+
+std::optional<mcdxbc::RootSignatureDesc>
+getRootSignature(RootSignatureBindingInfo &RSBI,
+                 dxil::ModuleMetadataInfo &MMI) {
+  if (MMI.EntryPropertyVec.size() == 0)
+    return std::nullopt;
+  std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
+      RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
+  if (!RootSigDesc)
+    return std::nullopt;
+  return RootSigDesc;
+}
+
 static void reportErrors(Module &M, DXILResourceMap &DRM,
-                         DXILResourceBindingInfo &DRBI) {
+                         DXILResourceBindingInfo &DRBI,
+                         RootSignatureBindingInfo &RSBI,
+                         dxil::ModuleMetadataInfo &MMI) {
   if (DRM.hasInvalidCounterDirection())
     reportInvalidDirection(M, DRM);
 
@@ -94,14 +147,83 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
 
   assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
                                        "DXILResourceImplicitBinding pass");
+
+  if (auto RSD = getRootSignature(RSBI, MMI)) {
+
+    RootSignatureBindingValidation Validation;
+    Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile));
+
+    for (const ResourceInfo &CBuf : DRM.cbuffers()) {
+      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+      if (!Validation.checkCRegBinding(Binding))
+        reportRegNotBound(M, "cbuffer", Binding);
+    }
+
+    for (const ResourceInfo &SRV : DRM.srvs()) {
+      ResourceInfo::ResourceBinding Binding = SRV.getBinding();
+      if (!Validation.checkTRegBinding(Binding))
+        reportRegNotBound(M, "srv", Binding);
+    }
+
+    for (const ResourceInfo &UAV : DRM.uavs()) {
+      ResourceInfo::ResourceBinding Binding = UAV.getBinding();
+      if (!Validation.checkURegBinding(Binding))
+        reportRegNotBound(M, "uav", Binding);
+    }
+
+    for (const ResourceInfo &Sampler : DRM.samplers()) {
+      ResourceInfo::ResourceBinding Binding = Sampler.getBinding();
+      if (!Validation.checkSamplerBinding(Binding))
+        reportRegNotBound(M, "sampler", Binding);
+    }
+  }
 }
 } // namespace
 
+void RootSignatureBindingValidation::addRsBindingInfo(
+    mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility) {
+  for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
+    const auto &[Type, Loc] =
+        RSD.ParametersContainer.getTypeAndLocForParameter(I);
+
+    const auto &Header = RSD.ParametersContainer.getHeader(I);
+    switch (Type) {
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case llvm::to_underlying(dxbc::RootParameterType::CBV): {
+      dxbc::RTS0::v2::RootDescriptor Desc =
+          RSD.ParametersContainer.getRootDescriptor(Loc);
+
+      if (Header.ShaderVisibility ==
+              llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+          Header.ShaderVisibility == llvm::to_underlying(Visibility))
+        addRange(Desc, Type);
+      break;
+    }
+    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      const mcdxbc::DescriptorTable &Table =
+          RSD.ParametersContainer.getDescriptorTable(Loc);
+
+      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
+        if (Header.ShaderVisibility ==
+                llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+            Header.ShaderVisibility == llvm::to_underlying(Visibility))
+          addRange(Range);
+      }
+      break;
+    }
+    }
+  }
+}
+
 PreservedAnalyses
 DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
-  reportErrors(M, DRM, DRBI);
+  RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
+  ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
+
+  reportErrors(M, DRM, DRBI, RSBI, MMI);
   return PreservedAnalyses::all();
 }
 
@@ -113,7 +235,12 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     DXILResourceBindingInfo &DRBI =
         getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
-    reportErrors(M, DRM, DRBI);
+    RootSignatureBindingInfo &RSBI =
+        getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
+    dxil::ModuleMetadataInfo &MMI =
+        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+
+    reportErrors(M, DRM, DRBI, RSBI, MMI);
     return false;
   }
   StringRef getPassName() const override {
@@ -125,10 +252,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
     AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceBindingWrapperPass>();
+    AU.addRequired<RootSignatureAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<DXILResourceWrapperPass>();
     AU.addPreserved<DXILResourceBindingWrapperPass>();
     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+    AU.addPreserved<RootSignatureAnalysisWrapper>();
   }
 };
 char DXILPostOptimizationValidationLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
index cb5e624514272..0fa0285425d7e 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
@@ -14,10 +14,129 @@
 #ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 #define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 
+#include "DXILRootSignature.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
 
+static uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) {
+  return (static_cast<uint64_t>(High) << 32) | Low;
+}
+
+class RootSignatureBindingValidation {
+  using MapT =
+      llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding,
+                        sizeof(llvm::dxil::ResourceInfo::ResourceBinding),
+                        llvm::IntervalMapInfo<uint64_t>>;
+
+private:
+  MapT::Allocator Allocator;
+  MapT CRegBindingsMap;
+  MapT TRegBindingsMap;
+  MapT URegBindingsMap;
+  MapT SamplersBindingsMap;
+
+  void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) {
+    assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::SRV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) &&
+           "Invalid Type in add Range Method");
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Desc.ShaderRegister;
+    Binding.Space = Desc.RegisterSpace;
+    Binding.Size = 1;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
+    switch (Type) {
+
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    }
+  }
+
+  void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) {
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Range.BaseShaderRegister;
+    Binding.Space = Range.RegisterSpace;
+    Binding.Size = Range.NumDescriptors;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
+    switch (Range.RangeType) {
+    case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+      SamplersBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    }
+  }
+
+public:
+  RootSignatureBindingValidation()
+      : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator),
+        URegBindingsMap(Allocator), SamplersBindingsMap(Allocator) {}
+
+  void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD,
+                        dxbc::ShaderVisibility Visibility);
+
+  bool checkCRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return CRegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return TRegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return URegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkSamplerBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return SamplersBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+};
+
 class DXILPostOptimizationValidation
     : public PassInfoMixin<DXILPostOptimizationValidation> {
 public:
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
index 9d89dbdd9107b..053721de1eb1f 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
@@ -13,7 +13,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 1 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"Sampler", i32 0, i32 1, i32 0, i32 -1, i32 1 }
+!6 = !{ !"Sampler", i32 1, i32 1, i32 0, i32 -1, i32 1 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 3 }
 
 
@@ -33,7 +33,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:             RangesOffset:    44
 ; DXC-NEXT:             Ranges:
 ; DXC-NEXT:               - RangeType:       3
-; DXC-NEXT:                 NumDescriptors:  0
+; DXC-NEXT:                 NumDescriptors:  1
 ; DXC-NEXT:                 BaseShaderRegister: 1
 ; DXC-NEXT:                 RegisterSpace:   0
 ; DXC-NEXT:                 OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
index b516d66180247..8e9b4b43b11a6 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
@@ -16,7 +16,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 2 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"SRV", i32 0, i32 1, i32 0, i32 -1, i32 4 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 4 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
 
 ; DXC:  - Name:            RTS0
@@ -35,7 +35,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:            RangesOffset:    44
 ; DXC-NEXT:            Ranges:
 ; DXC-NEXT:              - RangeType:       0
-; DXC-NEXT:                NumDescriptors:  0
+; DXC-NEXT:                NumDescriptors:  1
 ; DXC-NEXT:                BaseShaderRegister: 1
 ; DXC-NEXT:                RegisterSpace:   0
 ; DXC-NEXT:                OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 2b29fd30a7a56..8d75249dc6ecb 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -31,6 +31,7 @@
 ; CHECK-NEXT:   DXIL Module Metadata analysis
 ; CHECK-NEXT:   DXIL Shader Flag Analysis
 ; CHECK-NEXT:   DXIL Translate Metadata
+; CHECK-NEXT:   DXIL Root Signature Analysis
 ; CHECK-NEXT:   DXIL Post Optimization Validation
 ; CHECK-NEXT:   DXIL Op Lowering
 ; CHECK-NEXT:   DXIL Prepare Module

@llvmbot
Copy link
Member

llvmbot commented Jul 5, 2025

@llvm/pr-subscribers-hlsl

Author: None (joaosaffran)

Changes

DXC checks if registers are correctly bound to root signature descriptors. This implements the same check.

Closes: 126645


Full diff: https://github.com/llvm/llvm-project/pull/146785.diff

7 Files Affected:

  • (added) clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl (+35)
  • (added) clang/test/SemaHLSL/RootSignature-Validation.hlsl (+33)
  • (modified) llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp (+133-3)
  • (modified) llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h (+119)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+1)
diff --git a/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
new file mode 100644
index 0000000000000..b590ed67e7085
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
@@ -0,0 +1,35 @@
+// RUN: not %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 | FileCheck %s
+
+// CHECK: error: register cbuffer (space=665, register=3) is not defined in Root Signature
+// CHECK: error: register srv (space=0, register=0) is not defined in Root Signature
+// CHECK: error: register uav (space=0, register=4294967295) is not defined in Root Signature
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=666, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space665) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967295);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
new file mode 100644
index 0000000000000..5a7f5baf00619
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
@@ -0,0 +1,33 @@
+// RUN: %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 
+
+// expected-no-diagnostics
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space1) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967294);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 398dcbb8d1737..a52a04323514c 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILPostOptimizationValidation.h"
+#include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallString.h"
@@ -84,8 +85,60 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
   }
 }
 
+static void reportRegNotBound(Module &M, Twine Type,
+                              ResourceInfo::ResourceBinding Binding) {
+  SmallString<128> Message;
+  raw_svector_ostream OS(Message);
+  OS << "register " << Type << " (space=" << Binding.Space
+     << ", register=" << Binding.LowerBound << ")"
+     << " is not defined in Root Signature";
+  M.getContext().diagnose(DiagnosticInfoGeneric(Message));
+}
+
+static dxbc::ShaderVisibility
+tripleToVisibility(llvm::Triple::EnvironmentType ET) {
+  assert((ET == Triple::Pixel || ET == Triple::Vertex ||
+          ET == Triple::Geometry || ET == Triple::Hull ||
+          ET == Triple::Domain || ET == Triple::Mesh ||
+          ET == Triple::Compute) &&
+         "Invalid Triple to shader stage conversion");
+
+  switch (ET) {
+  case Triple::Pixel:
+    return dxbc::ShaderVisibility::Pixel;
+  case Triple::Vertex:
+    return dxbc::ShaderVisibility::Vertex;
+  case Triple::Geometry:
+    return dxbc::ShaderVisibility::Geometry;
+  case Triple::Hull:
+    return dxbc::ShaderVisibility::Hull;
+  case Triple::Domain:
+    return dxbc::ShaderVisibility::Domain;
+  case Triple::Mesh:
+    return dxbc::ShaderVisibility::Mesh;
+  case Triple::Compute:
+    return dxbc::ShaderVisibility::All;
+  default:
+    llvm_unreachable("Invalid triple to shader stage conversion");
+  }
+}
+
+std::optional<mcdxbc::RootSignatureDesc>
+getRootSignature(RootSignatureBindingInfo &RSBI,
+                 dxil::ModuleMetadataInfo &MMI) {
+  if (MMI.EntryPropertyVec.size() == 0)
+    return std::nullopt;
+  std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
+      RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
+  if (!RootSigDesc)
+    return std::nullopt;
+  return RootSigDesc;
+}
+
 static void reportErrors(Module &M, DXILResourceMap &DRM,
-                         DXILResourceBindingInfo &DRBI) {
+                         DXILResourceBindingInfo &DRBI,
+                         RootSignatureBindingInfo &RSBI,
+                         dxil::ModuleMetadataInfo &MMI) {
   if (DRM.hasInvalidCounterDirection())
     reportInvalidDirection(M, DRM);
 
@@ -94,14 +147,83 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
 
   assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
                                        "DXILResourceImplicitBinding pass");
+
+  if (auto RSD = getRootSignature(RSBI, MMI)) {
+
+    RootSignatureBindingValidation Validation;
+    Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile));
+
+    for (const ResourceInfo &CBuf : DRM.cbuffers()) {
+      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+      if (!Validation.checkCRegBinding(Binding))
+        reportRegNotBound(M, "cbuffer", Binding);
+    }
+
+    for (const ResourceInfo &SRV : DRM.srvs()) {
+      ResourceInfo::ResourceBinding Binding = SRV.getBinding();
+      if (!Validation.checkTRegBinding(Binding))
+        reportRegNotBound(M, "srv", Binding);
+    }
+
+    for (const ResourceInfo &UAV : DRM.uavs()) {
+      ResourceInfo::ResourceBinding Binding = UAV.getBinding();
+      if (!Validation.checkURegBinding(Binding))
+        reportRegNotBound(M, "uav", Binding);
+    }
+
+    for (const ResourceInfo &Sampler : DRM.samplers()) {
+      ResourceInfo::ResourceBinding Binding = Sampler.getBinding();
+      if (!Validation.checkSamplerBinding(Binding))
+        reportRegNotBound(M, "sampler", Binding);
+    }
+  }
 }
 } // namespace
 
+void RootSignatureBindingValidation::addRsBindingInfo(
+    mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility) {
+  for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
+    const auto &[Type, Loc] =
+        RSD.ParametersContainer.getTypeAndLocForParameter(I);
+
+    const auto &Header = RSD.ParametersContainer.getHeader(I);
+    switch (Type) {
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case llvm::to_underlying(dxbc::RootParameterType::CBV): {
+      dxbc::RTS0::v2::RootDescriptor Desc =
+          RSD.ParametersContainer.getRootDescriptor(Loc);
+
+      if (Header.ShaderVisibility ==
+              llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+          Header.ShaderVisibility == llvm::to_underlying(Visibility))
+        addRange(Desc, Type);
+      break;
+    }
+    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      const mcdxbc::DescriptorTable &Table =
+          RSD.ParametersContainer.getDescriptorTable(Loc);
+
+      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
+        if (Header.ShaderVisibility ==
+                llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+            Header.ShaderVisibility == llvm::to_underlying(Visibility))
+          addRange(Range);
+      }
+      break;
+    }
+    }
+  }
+}
+
 PreservedAnalyses
 DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
-  reportErrors(M, DRM, DRBI);
+  RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
+  ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
+
+  reportErrors(M, DRM, DRBI, RSBI, MMI);
   return PreservedAnalyses::all();
 }
 
@@ -113,7 +235,12 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     DXILResourceBindingInfo &DRBI =
         getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
-    reportErrors(M, DRM, DRBI);
+    RootSignatureBindingInfo &RSBI =
+        getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
+    dxil::ModuleMetadataInfo &MMI =
+        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+
+    reportErrors(M, DRM, DRBI, RSBI, MMI);
     return false;
   }
   StringRef getPassName() const override {
@@ -125,10 +252,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
     AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceBindingWrapperPass>();
+    AU.addRequired<RootSignatureAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<DXILResourceWrapperPass>();
     AU.addPreserved<DXILResourceBindingWrapperPass>();
     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+    AU.addPreserved<RootSignatureAnalysisWrapper>();
   }
 };
 char DXILPostOptimizationValidationLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
index cb5e624514272..0fa0285425d7e 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
@@ -14,10 +14,129 @@
 #ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 #define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 
+#include "DXILRootSignature.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
 
+static uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) {
+  return (static_cast<uint64_t>(High) << 32) | Low;
+}
+
+class RootSignatureBindingValidation {
+  using MapT =
+      llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding,
+                        sizeof(llvm::dxil::ResourceInfo::ResourceBinding),
+                        llvm::IntervalMapInfo<uint64_t>>;
+
+private:
+  MapT::Allocator Allocator;
+  MapT CRegBindingsMap;
+  MapT TRegBindingsMap;
+  MapT URegBindingsMap;
+  MapT SamplersBindingsMap;
+
+  void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) {
+    assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::SRV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) &&
+           "Invalid Type in add Range Method");
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Desc.ShaderRegister;
+    Binding.Space = Desc.RegisterSpace;
+    Binding.Size = 1;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
+    switch (Type) {
+
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    }
+  }
+
+  void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) {
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Range.BaseShaderRegister;
+    Binding.Space = Range.RegisterSpace;
+    Binding.Size = Range.NumDescriptors;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
+    switch (Range.RangeType) {
+    case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+      SamplersBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    }
+  }
+
+public:
+  RootSignatureBindingValidation()
+      : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator),
+        URegBindingsMap(Allocator), SamplersBindingsMap(Allocator) {}
+
+  void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD,
+                        dxbc::ShaderVisibility Visibility);
+
+  bool checkCRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return CRegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return TRegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return URegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkSamplerBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return SamplersBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+};
+
 class DXILPostOptimizationValidation
     : public PassInfoMixin<DXILPostOptimizationValidation> {
 public:
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
index 9d89dbdd9107b..053721de1eb1f 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
@@ -13,7 +13,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 1 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"Sampler", i32 0, i32 1, i32 0, i32 -1, i32 1 }
+!6 = !{ !"Sampler", i32 1, i32 1, i32 0, i32 -1, i32 1 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 3 }
 
 
@@ -33,7 +33,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:             RangesOffset:    44
 ; DXC-NEXT:             Ranges:
 ; DXC-NEXT:               - RangeType:       3
-; DXC-NEXT:                 NumDescriptors:  0
+; DXC-NEXT:                 NumDescriptors:  1
 ; DXC-NEXT:                 BaseShaderRegister: 1
 ; DXC-NEXT:                 RegisterSpace:   0
 ; DXC-NEXT:                 OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
index b516d66180247..8e9b4b43b11a6 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
@@ -16,7 +16,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 2 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"SRV", i32 0, i32 1, i32 0, i32 -1, i32 4 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 4 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
 
 ; DXC:  - Name:            RTS0
@@ -35,7 +35,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:            RangesOffset:    44
 ; DXC-NEXT:            Ranges:
 ; DXC-NEXT:              - RangeType:       0
-; DXC-NEXT:                NumDescriptors:  0
+; DXC-NEXT:                NumDescriptors:  1
 ; DXC-NEXT:                BaseShaderRegister: 1
 ; DXC-NEXT:                RegisterSpace:   0
 ; DXC-NEXT:                OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 2b29fd30a7a56..8d75249dc6ecb 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -31,6 +31,7 @@
 ; CHECK-NEXT:   DXIL Module Metadata analysis
 ; CHECK-NEXT:   DXIL Shader Flag Analysis
 ; CHECK-NEXT:   DXIL Translate Metadata
+; CHECK-NEXT:   DXIL Root Signature Analysis
 ; CHECK-NEXT:   DXIL Post Optimization Validation
 ; CHECK-NEXT:   DXIL Op Lowering
 ; CHECK-NEXT:   DXIL Prepare Module

@llvmbot
Copy link
Member

llvmbot commented Jul 5, 2025

@llvm/pr-subscribers-backend-directx

Author: None (joaosaffran)

Changes

DXC checks if registers are correctly bound to root signature descriptors. This implements the same check.

Closes: 126645


Full diff: https://github.com/llvm/llvm-project/pull/146785.diff

7 Files Affected:

  • (added) clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl (+35)
  • (added) clang/test/SemaHLSL/RootSignature-Validation.hlsl (+33)
  • (modified) llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp (+133-3)
  • (modified) llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h (+119)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll (+2-2)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+1)
diff --git a/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
new file mode 100644
index 0000000000000..b590ed67e7085
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation-Fail.hlsl
@@ -0,0 +1,35 @@
+// RUN: not %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 | FileCheck %s
+
+// CHECK: error: register cbuffer (space=665, register=3) is not defined in Root Signature
+// CHECK: error: register srv (space=0, register=0) is not defined in Root Signature
+// CHECK: error: register uav (space=0, register=4294967295) is not defined in Root Signature
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=666, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space665) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967295);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/clang/test/SemaHLSL/RootSignature-Validation.hlsl b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
new file mode 100644
index 0000000000000..5a7f5baf00619
--- /dev/null
+++ b/clang/test/SemaHLSL/RootSignature-Validation.hlsl
@@ -0,0 +1,33 @@
+// RUN: %clang_dxc -T cs_6_6 -E CSMain %s 2>&1 
+
+// expected-no-diagnostics
+
+
+#define ROOT_SIGNATURE \
+    "CBV(b3, space=1, visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(SRV(t0, space=0, numDescriptors=1), visibility=SHADER_VISIBILITY_ALL), " \
+    "DescriptorTable(Sampler(s0, numDescriptors=2), visibility=SHADER_VISIBILITY_VERTEX), " \
+    "DescriptorTable(UAV(u0, numDescriptors=unbounded), visibility=SHADER_VISIBILITY_ALL)"
+
+cbuffer CB : register(b3, space1) {
+  float a;
+}
+
+StructuredBuffer<int> In : register(t0, space0);
+RWStructuredBuffer<int> Out : register(u0);
+
+RWBuffer<float> UAV : register(u4294967294);
+
+RWBuffer<float> UAV1 : register(u2), UAV2 : register(u4);
+
+RWBuffer<float> UAV3 : register(space0);
+
+
+
+// Compute Shader for UAV testing
+[numthreads(8, 8, 1)]
+[RootSignature(ROOT_SIGNATURE)]
+void CSMain(uint id : SV_GroupID)
+{
+    Out[0] = a + id + In[0] + UAV[0] + UAV1[0] + UAV3[0];
+}
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
index 398dcbb8d1737..a52a04323514c 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILPostOptimizationValidation.h"
+#include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallString.h"
@@ -84,8 +85,60 @@ static void reportOverlappingBinding(Module &M, DXILResourceMap &DRM) {
   }
 }
 
+static void reportRegNotBound(Module &M, Twine Type,
+                              ResourceInfo::ResourceBinding Binding) {
+  SmallString<128> Message;
+  raw_svector_ostream OS(Message);
+  OS << "register " << Type << " (space=" << Binding.Space
+     << ", register=" << Binding.LowerBound << ")"
+     << " is not defined in Root Signature";
+  M.getContext().diagnose(DiagnosticInfoGeneric(Message));
+}
+
+static dxbc::ShaderVisibility
+tripleToVisibility(llvm::Triple::EnvironmentType ET) {
+  assert((ET == Triple::Pixel || ET == Triple::Vertex ||
+          ET == Triple::Geometry || ET == Triple::Hull ||
+          ET == Triple::Domain || ET == Triple::Mesh ||
+          ET == Triple::Compute) &&
+         "Invalid Triple to shader stage conversion");
+
+  switch (ET) {
+  case Triple::Pixel:
+    return dxbc::ShaderVisibility::Pixel;
+  case Triple::Vertex:
+    return dxbc::ShaderVisibility::Vertex;
+  case Triple::Geometry:
+    return dxbc::ShaderVisibility::Geometry;
+  case Triple::Hull:
+    return dxbc::ShaderVisibility::Hull;
+  case Triple::Domain:
+    return dxbc::ShaderVisibility::Domain;
+  case Triple::Mesh:
+    return dxbc::ShaderVisibility::Mesh;
+  case Triple::Compute:
+    return dxbc::ShaderVisibility::All;
+  default:
+    llvm_unreachable("Invalid triple to shader stage conversion");
+  }
+}
+
+std::optional<mcdxbc::RootSignatureDesc>
+getRootSignature(RootSignatureBindingInfo &RSBI,
+                 dxil::ModuleMetadataInfo &MMI) {
+  if (MMI.EntryPropertyVec.size() == 0)
+    return std::nullopt;
+  std::optional<mcdxbc::RootSignatureDesc> RootSigDesc =
+      RSBI.getDescForFunction(MMI.EntryPropertyVec[0].Entry);
+  if (!RootSigDesc)
+    return std::nullopt;
+  return RootSigDesc;
+}
+
 static void reportErrors(Module &M, DXILResourceMap &DRM,
-                         DXILResourceBindingInfo &DRBI) {
+                         DXILResourceBindingInfo &DRBI,
+                         RootSignatureBindingInfo &RSBI,
+                         dxil::ModuleMetadataInfo &MMI) {
   if (DRM.hasInvalidCounterDirection())
     reportInvalidDirection(M, DRM);
 
@@ -94,14 +147,83 @@ static void reportErrors(Module &M, DXILResourceMap &DRM,
 
   assert(!DRBI.hasImplicitBinding() && "implicit bindings should be handled in "
                                        "DXILResourceImplicitBinding pass");
+
+  if (auto RSD = getRootSignature(RSBI, MMI)) {
+
+    RootSignatureBindingValidation Validation;
+    Validation.addRsBindingInfo(*RSD, tripleToVisibility(MMI.ShaderProfile));
+
+    for (const ResourceInfo &CBuf : DRM.cbuffers()) {
+      ResourceInfo::ResourceBinding Binding = CBuf.getBinding();
+      if (!Validation.checkCRegBinding(Binding))
+        reportRegNotBound(M, "cbuffer", Binding);
+    }
+
+    for (const ResourceInfo &SRV : DRM.srvs()) {
+      ResourceInfo::ResourceBinding Binding = SRV.getBinding();
+      if (!Validation.checkTRegBinding(Binding))
+        reportRegNotBound(M, "srv", Binding);
+    }
+
+    for (const ResourceInfo &UAV : DRM.uavs()) {
+      ResourceInfo::ResourceBinding Binding = UAV.getBinding();
+      if (!Validation.checkURegBinding(Binding))
+        reportRegNotBound(M, "uav", Binding);
+    }
+
+    for (const ResourceInfo &Sampler : DRM.samplers()) {
+      ResourceInfo::ResourceBinding Binding = Sampler.getBinding();
+      if (!Validation.checkSamplerBinding(Binding))
+        reportRegNotBound(M, "sampler", Binding);
+    }
+  }
 }
 } // namespace
 
+void RootSignatureBindingValidation::addRsBindingInfo(
+    mcdxbc::RootSignatureDesc &RSD, dxbc::ShaderVisibility Visibility) {
+  for (size_t I = 0; I < RSD.ParametersContainer.size(); I++) {
+    const auto &[Type, Loc] =
+        RSD.ParametersContainer.getTypeAndLocForParameter(I);
+
+    const auto &Header = RSD.ParametersContainer.getHeader(I);
+    switch (Type) {
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+    case llvm::to_underlying(dxbc::RootParameterType::CBV): {
+      dxbc::RTS0::v2::RootDescriptor Desc =
+          RSD.ParametersContainer.getRootDescriptor(Loc);
+
+      if (Header.ShaderVisibility ==
+              llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+          Header.ShaderVisibility == llvm::to_underlying(Visibility))
+        addRange(Desc, Type);
+      break;
+    }
+    case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+      const mcdxbc::DescriptorTable &Table =
+          RSD.ParametersContainer.getDescriptorTable(Loc);
+
+      for (const dxbc::RTS0::v2::DescriptorRange &Range : Table.Ranges) {
+        if (Header.ShaderVisibility ==
+                llvm::to_underlying(dxbc::ShaderVisibility::All) ||
+            Header.ShaderVisibility == llvm::to_underlying(Visibility))
+          addRange(Range);
+      }
+      break;
+    }
+    }
+  }
+}
+
 PreservedAnalyses
 DXILPostOptimizationValidation::run(Module &M, ModuleAnalysisManager &MAM) {
   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   DXILResourceBindingInfo &DRBI = MAM.getResult<DXILResourceBindingAnalysis>(M);
-  reportErrors(M, DRM, DRBI);
+  RootSignatureBindingInfo &RSBI = MAM.getResult<RootSignatureAnalysis>(M);
+  ModuleMetadataInfo &MMI = MAM.getResult<DXILMetadataAnalysis>(M);
+
+  reportErrors(M, DRM, DRBI, RSBI, MMI);
   return PreservedAnalyses::all();
 }
 
@@ -113,7 +235,12 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     DXILResourceBindingInfo &DRBI =
         getAnalysis<DXILResourceBindingWrapperPass>().getBindingInfo();
-    reportErrors(M, DRM, DRBI);
+    RootSignatureBindingInfo &RSBI =
+        getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo();
+    dxil::ModuleMetadataInfo &MMI =
+        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+
+    reportErrors(M, DRM, DRBI, RSBI, MMI);
     return false;
   }
   StringRef getPassName() const override {
@@ -125,10 +252,13 @@ class DXILPostOptimizationValidationLegacy : public ModulePass {
   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
     AU.addRequired<DXILResourceWrapperPass>();
     AU.addRequired<DXILResourceBindingWrapperPass>();
+    AU.addRequired<RootSignatureAnalysisWrapper>();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<DXILResourceWrapperPass>();
     AU.addPreserved<DXILResourceBindingWrapperPass>();
     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+    AU.addPreserved<RootSignatureAnalysisWrapper>();
   }
 };
 char DXILPostOptimizationValidationLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
index cb5e624514272..0fa0285425d7e 100644
--- a/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
+++ b/llvm/lib/Target/DirectX/DXILPostOptimizationValidation.h
@@ -14,10 +14,129 @@
 #ifndef LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 #define LLVM_LIB_TARGET_DIRECTX_DXILPOSTOPTIMIZATIONVALIDATION_H
 
+#include "DXILRootSignature.h"
+#include "llvm/ADT/IntervalMap.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/PassManager.h"
 
 namespace llvm {
 
+static uint64_t combineUint32ToUint64(uint32_t High, uint32_t Low) {
+  return (static_cast<uint64_t>(High) << 32) | Low;
+}
+
+class RootSignatureBindingValidation {
+  using MapT =
+      llvm::IntervalMap<uint64_t, dxil::ResourceInfo::ResourceBinding,
+                        sizeof(llvm::dxil::ResourceInfo::ResourceBinding),
+                        llvm::IntervalMapInfo<uint64_t>>;
+
+private:
+  MapT::Allocator Allocator;
+  MapT CRegBindingsMap;
+  MapT TRegBindingsMap;
+  MapT URegBindingsMap;
+  MapT SamplersBindingsMap;
+
+  void addRange(const dxbc::RTS0::v2::RootDescriptor &Desc, uint32_t Type) {
+    assert((Type == llvm::to_underlying(dxbc::RootParameterType::CBV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::SRV) ||
+            Type == llvm::to_underlying(dxbc::RootParameterType::UAV)) &&
+           "Invalid Type in add Range Method");
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Desc.ShaderRegister;
+    Binding.Space = Desc.RegisterSpace;
+    Binding.Size = 1;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
+    switch (Type) {
+
+    case llvm::to_underlying(dxbc::RootParameterType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::RootParameterType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::RootParameterType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    }
+  }
+
+  void addRange(const dxbc::RTS0::v2::DescriptorRange &Range) {
+
+    llvm::dxil::ResourceInfo::ResourceBinding Binding;
+    Binding.LowerBound = Range.BaseShaderRegister;
+    Binding.Space = Range.RegisterSpace;
+    Binding.Size = Range.NumDescriptors;
+
+    uint64_t LowRange =
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound);
+    uint64_t HighRange = combineUint32ToUint64(
+        Binding.Space, Binding.LowerBound + Binding.Size - 1);
+
+    assert(LowRange <= HighRange && "Invalid range configuration");
+
+    switch (Range.RangeType) {
+    case llvm::to_underlying(dxbc::DescriptorRangeType::CBV):
+      CRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::SRV):
+      TRegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::UAV):
+      URegBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    case llvm::to_underlying(dxbc::DescriptorRangeType::Sampler):
+      SamplersBindingsMap.insert(LowRange, HighRange, Binding);
+      break;
+    }
+  }
+
+public:
+  RootSignatureBindingValidation()
+      : Allocator(), CRegBindingsMap(Allocator), TRegBindingsMap(Allocator),
+        URegBindingsMap(Allocator), SamplersBindingsMap(Allocator) {}
+
+  void addRsBindingInfo(mcdxbc::RootSignatureDesc &RSD,
+                        dxbc::ShaderVisibility Visibility);
+
+  bool checkCRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return CRegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkTRegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return TRegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkURegBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return URegBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+
+  bool checkSamplerBinding(dxil::ResourceInfo::ResourceBinding Binding) {
+    return SamplersBindingsMap.overlaps(
+        combineUint32ToUint64(Binding.Space, Binding.LowerBound),
+        combineUint32ToUint64(Binding.Space,
+                              Binding.LowerBound + Binding.Size - 1));
+  }
+};
+
 class DXILPostOptimizationValidation
     : public PassInfoMixin<DXILPostOptimizationValidation> {
 public:
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
index 9d89dbdd9107b..053721de1eb1f 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-AllValidFlagCombinationsV1.ll
@@ -13,7 +13,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 1 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"Sampler", i32 0, i32 1, i32 0, i32 -1, i32 1 }
+!6 = !{ !"Sampler", i32 1, i32 1, i32 0, i32 -1, i32 1 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 3 }
 
 
@@ -33,7 +33,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:             RangesOffset:    44
 ; DXC-NEXT:             Ranges:
 ; DXC-NEXT:               - RangeType:       3
-; DXC-NEXT:                 NumDescriptors:  0
+; DXC-NEXT:                 NumDescriptors:  1
 ; DXC-NEXT:                 BaseShaderRegister: 1
 ; DXC-NEXT:                 RegisterSpace:   0
 ; DXC-NEXT:                 OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
index b516d66180247..8e9b4b43b11a6 100644
--- a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable.ll
@@ -16,7 +16,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 !2 = !{ ptr @main, !3, i32 2 } ; function, root signature
 !3 = !{ !5 } ; list of root signature elements
 !5 = !{ !"DescriptorTable", i32 0, !6, !7 }
-!6 = !{ !"SRV", i32 0, i32 1, i32 0, i32 -1, i32 4 }
+!6 = !{ !"SRV", i32 1, i32 1, i32 0, i32 -1, i32 4 }
 !7 = !{ !"UAV", i32 5, i32 1, i32 10, i32 5, i32 2 }
 
 ; DXC:  - Name:            RTS0
@@ -35,7 +35,7 @@ attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
 ; DXC-NEXT:            RangesOffset:    44
 ; DXC-NEXT:            Ranges:
 ; DXC-NEXT:              - RangeType:       0
-; DXC-NEXT:                NumDescriptors:  0
+; DXC-NEXT:                NumDescriptors:  1
 ; DXC-NEXT:                BaseShaderRegister: 1
 ; DXC-NEXT:                RegisterSpace:   0
 ; DXC-NEXT:                OffsetInDescriptorsFromTableStart: 4294967295
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 2b29fd30a7a56..8d75249dc6ecb 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -31,6 +31,7 @@
 ; CHECK-NEXT:   DXIL Module Metadata analysis
 ; CHECK-NEXT:   DXIL Shader Flag Analysis
 ; CHECK-NEXT:   DXIL Translate Metadata
+; CHECK-NEXT:   DXIL Root Signature Analysis
 ; CHECK-NEXT:   DXIL Post Optimization Validation
 ; CHECK-NEXT:   DXIL Op Lowering
 ; CHECK-NEXT:   DXIL Prepare Module

@bogner
Copy link
Contributor

bogner commented Aug 27, 2025

I don't think we need this new BusyBindingInfo structure at all, and this can all be quite a bit simpler. Consider a function like so on BindingInfoBuilder:

  bool isBound(dxil::ResourceClass RC, uint32_t Space, uint32_t LowerBound,
               uint32_t UpperBound) const {
    auto It =
        llvm::upper_bound(Bindings, Binding{RC, Space, LowerBound, 0, nullptr});
    if (It == Bindings.begin())
      return false;
    --It;
    return It->RC == RC && It->Space == Space && It->LowerBound <= LowerBound &&
           It->UpperBound >= UpperBound;
  }

With this, the only thing we need to do in DXILPostOptimizationValidation is loop over the resources in DRM and check if they're bound:

  for (const ResourceInfo &RI : DRM) {
    const ResourceInfo::ResourceBinding &Binding = RI.getBinding();
    ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
    if (!Builder.isBound(RC, Binding.Space, Binding.LowerBound,
                         Binding.LowerBound + Binding.Size - 1))
      reportRegNotBound(M, RC, Binding);
  }

We discussed that it's a little awkward to keep using the Builder for this after we've built the BindingInfo freelist, but even if we want to do something to be more explicit about lifetime and ownership, a class or struct that simply has the sorted Binding vector and the isBound function would be sufficient here. That said, I think the simplicity of hanging this off of the builder has something to be said for it.

See bogner@b5b7483 for a demonstration of what I'm saying above.

@joaosaffran joaosaffran requested review from bogner and inbelic August 28, 2025 17:50
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small comments about the getBoundRegs function, but otherwise this LGTM


LLVM_ABI BoundRegs getBoundRegs() {
assert(std::is_sorted(Bindings.begin(), Bindings.end()) &&
"Bindings must be sorted");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::is_sorted already makes it clear what we're asserting. Might be better to explain why with something like "getBoundRegs should only be called after calculateBindingInfo"

[&HasOverlap](auto, auto) { HasOverlap = true; });
}

LLVM_ABI BoundRegs getBoundRegs() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think takeBoundRegs would be a clearer name, since this modifies/invalidates the BindingInfoBuilder

@joaosaffran joaosaffran merged commit 36ebd17 into llvm:main Aug 28, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:DirectX clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants