Skip to content

[DirectX] Error handling improve in root signature metadata Parser #149232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

joaosaffran
Copy link
Contributor

This PR addresses #144465 (comment). Using joinErrors and llvm:Error instead of boolean values.

@llvmbot llvmbot added backend:DirectX HLSL HLSL Language Support labels Jul 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 17, 2025

@llvm/pr-subscribers-hlsl

Author: None (joaosaffran)

Changes

This PR addresses #144465 (comment). Using joinErrors and llvm:Error instead of boolean values.


Patch is 39.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149232.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h (+86-21)
  • (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+197-134)
  • (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+14-10)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll (+1-1)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index cd5966db42b5f..b3705a2132021 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -26,6 +26,73 @@ class Metadata;
 namespace hlsl {
 namespace rootsig {
 
+template <typename T>
+class RootSignatureValidationError
+    : public ErrorInfo<RootSignatureValidationError<T>> {
+public:
+  static char ID;
+  std::string ParamName;
+  T Value;
+
+  RootSignatureValidationError(StringRef ParamName, T Value)
+      : ParamName(ParamName.str()), Value(Value) {}
+
+  void log(raw_ostream &OS) const override {
+    OS << "Invalid value for " << ParamName << ": " << Value;
+  }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
+public:
+  static char ID;
+  std::string Message;
+
+  GenericRSMetadataError(Twine Message) : Message(Message.str()) {}
+
+  void log(raw_ostream &OS) const override { OS << Message; }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
+public:
+  static char ID;
+  std::string ElementName;
+
+  InvalidRSMetadataFormat(StringRef ElementName)
+      : ElementName(ElementName.str()) {}
+
+  void log(raw_ostream &OS) const override {
+    OS << "Invalid format for  " << ElementName;
+  }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
+public:
+  static char ID;
+  std::string ParamName;
+
+  InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
+
+  void log(raw_ostream &OS) const override {
+    OS << "Invalid value for " << ParamName;
+  }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
 class MetadataBuilder {
 public:
   MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -67,29 +134,27 @@ class MetadataParser {
   MetadataParser(MDNode *Root) : Root(Root) {}
 
   /// Iterates through root signature and converts them into MapT
-  LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx,
-                                   mcdxbc::RootSignatureDesc &RSD);
+  LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
+  ParseRootSignature(uint32_t Version);
 
 private:
-  bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                      MDNode *RootFlagNode);
-  bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                          MDNode *RootConstantNode);
-  bool parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                            MDNode *RootDescriptorNode,
-                            RootSignatureElementKind ElementKind);
-  bool parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table,
-                            MDNode *RangeDescriptorNode);
-  bool parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                            MDNode *DescriptorTableNode);
-  bool parseRootSignatureElement(LLVMContext *Ctx,
-                                 mcdxbc::RootSignatureDesc &RSD,
-                                 MDNode *Element);
-  bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                          MDNode *StaticSamplerNode);
-
-  bool validateRootSignature(LLVMContext *Ctx,
-                             const llvm::mcdxbc::RootSignatureDesc &RSD);
+  llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+                             MDNode *RootFlagNode);
+  llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+                                 MDNode *RootConstantNode);
+  llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+                                   MDNode *RootDescriptorNode,
+                                   RootSignatureElementKind ElementKind);
+  llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+                                   MDNode *RangeDescriptorNode);
+  llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+                                   MDNode *DescriptorTableNode);
+  llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+                                        MDNode *Element);
+  llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+                                 MDNode *StaticSamplerNode);
+
+  llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
 
   MDNode *Root;
 };
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..41c23ecb692ea 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -13,7 +13,6 @@
 
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
-#include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -22,7 +21,12 @@ namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
+template <typename T> char RootSignatureValidationError<T>::ID;
+
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
                                                  unsigned int OpId) {
   if (auto *CI =
           mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
@@ -30,14 +34,14 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
   return std::nullopt;
 }
 
-static std::optional<float> extractMdFloatValue(MDNode *Node,
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
                                                 unsigned int OpId) {
   if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
     return CI->getValueAPF().convertToFloat();
   return std::nullopt;
 }
 
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
                                                      unsigned int OpId) {
   MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
   if (NodeText == nullptr)
@@ -45,19 +49,6 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
   return NodeText->getString();
 }
 
-static bool reportError(LLVMContext *Ctx, Twine Message,
-                        DiagnosticSeverity Severity = DS_Error) {
-  Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
-  return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
-                             uint32_t Value) {
-  Ctx->diagnose(DiagnosticInfoGeneric(
-      "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
-  return true;
-}
-
 static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
     {"CBV", dxil::ResourceClass::CBuffer},
     {"SRV", dxil::ResourceClass::SRV},
@@ -227,27 +218,23 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
   return MDNode::get(Ctx, Operands);
 }
 
-bool MetadataParser::parseRootFlags(LLVMContext *Ctx,
-                                    mcdxbc::RootSignatureDesc &RSD,
-                                    MDNode *RootFlagNode) {
-
+llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+                                           MDNode *RootFlagNode) {
   if (RootFlagNode->getNumOperands() != 2)
-    return reportError(Ctx, "Invalid format for RootFlag Element");
+    return make_error<InvalidRSMetadataFormat>("RootFlag Element");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
     RSD.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for RootFlag");
+    return make_error<InvalidRSMetadataValue>("RootFlag");
 
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
-                                        mcdxbc::RootSignatureDesc &RSD,
-                                        MDNode *RootConstantNode) {
-
+llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+                                               MDNode *RootConstantNode) {
   if (RootConstantNode->getNumOperands() != 5)
-    return reportError(Ctx, "Invalid format for RootConstants Element");
+    return make_error<InvalidRSMetadataFormat>("RootConstants Element");
 
   dxbc::RTS0::v1::RootParameterHeader Header;
   // The parameter offset doesn't matter here - we recalculate it during
@@ -258,39 +245,40 @@ bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
 
   dxbc::RTS0::v1::RootConstants Constants;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
     Constants.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return make_error<InvalidRSMetadataValue>("ShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
     Constants.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
     Constants.Num32BitValues = *Val;
   else
-    return reportError(Ctx, "Invalid value for Num32BitValues");
+    return make_error<InvalidRSMetadataValue>("Num32BitValues");
 
   RSD.ParametersContainer.addParameter(Header, Constants);
 
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseRootDescriptors(
-    LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-    MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) {
+llvm::Error
+MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+                                     MDNode *RootDescriptorNode,
+                                     RootSignatureElementKind ElementKind) {
   assert(ElementKind == RootSignatureElementKind::SRV ||
          ElementKind == RootSignatureElementKind::UAV ||
          ElementKind == RootSignatureElementKind::CBV &&
-             "parseRootDescriptors should only be called with RootDescriptor "
+             "parseRootDescriptors should only be called with RootDescriptor"
              "element kind.");
   if (RootDescriptorNode->getNumOperands() != 5)
-    return reportError(Ctx, "Invalid format for Root Descriptor Element");
+    return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
 
   dxbc::RTS0::v1::RootParameterHeader Header;
   switch (ElementKind) {
@@ -311,40 +299,38 @@ bool MetadataParser::parseRootDescriptors(
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
 
   dxbc::RTS0::v2::RootDescriptor Descriptor;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
     Descriptor.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return make_error<InvalidRSMetadataValue>("ShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
     Descriptor.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (RSD.Version == 1) {
     RSD.ParametersContainer.addParameter(Header, Descriptor);
-    return false;
+    return llvm::Error::success();
   }
   assert(RSD.Version > 1);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
     Descriptor.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+    return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
 
   RSD.ParametersContainer.addParameter(Header, Descriptor);
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
-                                          mcdxbc::DescriptorTable &Table,
-                                          MDNode *RangeDescriptorNode) {
-
+llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+                                                 MDNode *RangeDescriptorNode) {
   if (RangeDescriptorNode->getNumOperands() != 6)
-    return reportError(Ctx, "Invalid format for Descriptor Range");
+    return make_error<InvalidRSMetadataFormat>("Descriptor Range");
 
   dxbc::RTS0::v2::DescriptorRange Range;
 
@@ -352,7 +338,7 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
       extractMdStringValue(RangeDescriptorNode, 0);
 
   if (!ElementText.has_value())
-    return reportError(Ctx, "Descriptor Range, first element is not a string.");
+    return make_error<InvalidRSMetadataFormat>("Descriptor Range");
 
   Range.RangeType =
       StringSwitch<uint32_t>(*ElementText)
@@ -364,50 +350,50 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
           .Default(~0U);
 
   if (Range.RangeType == ~0U)
-    return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+    return make_error<GenericRSMetadataError>("Invalid Descriptor Range type:" +
+                                              *ElementText);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
     Range.NumDescriptors = *Val;
   else
-    return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+    return make_error<GenericRSMetadataError>("Number of Descriptor in Range");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
     Range.BaseShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for BaseShaderRegister");
+    return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
     Range.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
     Range.OffsetInDescriptorsFromTableStart = *Val;
   else
-    return reportError(Ctx,
-                       "Invalid value for OffsetInDescriptorsFromTableStart");
+    return make_error<InvalidRSMetadataValue>(
+        "OffsetInDescriptorsFromTableStart");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
     Range.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+    return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
 
   Table.Ranges.push_back(Range);
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
-                                          mcdxbc::RootSignatureDesc &RSD,
-                                          MDNode *DescriptorTableNode) {
+llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+                                                 MDNode *DescriptorTableNode) {
   const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
   if (NumOperands < 2)
-    return reportError(Ctx, "Invalid format for Descriptor Table");
+    return make_error<InvalidRSMetadataFormat>("Descriptor Table");
 
   dxbc::RTS0::v1::RootParameterHeader Header;
   if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
 
   mcdxbc::DescriptorTable Table;
   Header.ParameterType =
@@ -416,98 +402,98 @@ bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
   for (unsigned int I = 2; I < NumOperands; I++) {
     MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
     if (Element == nullptr)
-      return reportError(Ctx, "Missing Root Element Metadata Node.");
+      return make_error<GenericRSMetadataError>(
+          "Missing Root Element Metadata Node.");
 
-    if (parseDescriptorRange(Ctx, Table, Element))
-      return true;
+    if (auto Err = parseDescriptorRange(Table, Element))
+      return Err;
   }
 
   RSD.ParametersContainer.addParameter(Header, Table);
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseStaticSampler(LLVMContext *Ctx,
-                                        mcdxbc::RootSignatureDesc &RSD,
-                                        MDNode *StaticSamplerNode) {
+llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+                                               MDNode *StaticSamplerNode) {
   if (StaticSamplerNode->getNumOperands() != 14)
-    return reportError(Ctx, "Invalid format for Static Sampler");
+    return make_error<InvalidRSMetadataFormat>("Static Sampler");
 
   dxbc::RTS0::v1::StaticSampler Sampler;
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
     Sampler.Filter = *Val;
   else
-    return reportError(Ctx, "Invalid value for Filter");
+    return make_error<InvalidRSMetadataValue>("Filter");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
     Sampler.AddressU = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressU");
+    return make_error<InvalidRSMetadataValue>("AddressU");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
     Sampler.AddressV = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressV");
+    return make_error<InvalidRSMetadataValue>("AddressV");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
     Sampler.AddressW = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressW");
+    return make_error<InvalidRSMetadataValue>("AddressW");
 
   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
     Sampler.MipLODBias = *Val;
   else
-    return reportError(Ctx, "Invalid value for MipLODBias");
+    return make_error<InvalidRSMetadataValue>("MipLODBias");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
     Sampler.MaxAnisotropy = *Val;
   else
-    return reportError(Ctx, "Invalid value for MaxAnisotropy");
+    return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
     Sampler.ComparisonFunc = *Val;
   else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+    return make_error<InvalidRSMetadataValue>("ComparisonFunc");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
     Sampler.BorderColor = *Val;
   else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+    return make_error<InvalidRSMetadataValue>("ComparisonFunc");
 
   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
     Sampler.MinLOD = *Val;
   else
-    return reportError(Ctx, "Invalid value for MinLOD");
+    return make_error<InvalidRSMetadataValue>("MinLOD");
 
   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
     Sampler.MaxLOD = *Val;
   else
-    return reportError(Ctx, "Invalid value for MaxLOD");
+    return make_error<InvalidRSMetadataValue>("MaxLOD");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
     Sa...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 17, 2025

@llvm/pr-subscribers-backend-directx

Author: None (joaosaffran)

Changes

This PR addresses #144465 (comment). Using joinErrors and llvm:Error instead of boolean values.


Patch is 39.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149232.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h (+86-21)
  • (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+197-134)
  • (modified) llvm/lib/Target/DirectX/DXILRootSignature.cpp (+14-10)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-DescriptorTable-Invalid-RangeType.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootDescriptor-Invalid-RegisterKind.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MaxLod.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLod.ll (+1-1)
  • (modified) llvm/test/CodeGen/DirectX/ContainerData/RootSignature-StaticSamplers-Invalid-MinLopBias.ll (+1-1)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index cd5966db42b5f..b3705a2132021 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -26,6 +26,73 @@ class Metadata;
 namespace hlsl {
 namespace rootsig {
 
+template <typename T>
+class RootSignatureValidationError
+    : public ErrorInfo<RootSignatureValidationError<T>> {
+public:
+  static char ID;
+  std::string ParamName;
+  T Value;
+
+  RootSignatureValidationError(StringRef ParamName, T Value)
+      : ParamName(ParamName.str()), Value(Value) {}
+
+  void log(raw_ostream &OS) const override {
+    OS << "Invalid value for " << ParamName << ": " << Value;
+  }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
+public:
+  static char ID;
+  std::string Message;
+
+  GenericRSMetadataError(Twine Message) : Message(Message.str()) {}
+
+  void log(raw_ostream &OS) const override { OS << Message; }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
+public:
+  static char ID;
+  std::string ElementName;
+
+  InvalidRSMetadataFormat(StringRef ElementName)
+      : ElementName(ElementName.str()) {}
+
+  void log(raw_ostream &OS) const override {
+    OS << "Invalid format for  " << ElementName;
+  }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
+public:
+  static char ID;
+  std::string ParamName;
+
+  InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName.str()) {}
+
+  void log(raw_ostream &OS) const override {
+    OS << "Invalid value for " << ParamName;
+  }
+
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
 class MetadataBuilder {
 public:
   MetadataBuilder(llvm::LLVMContext &Ctx, ArrayRef<RootElement> Elements)
@@ -67,29 +134,27 @@ class MetadataParser {
   MetadataParser(MDNode *Root) : Root(Root) {}
 
   /// Iterates through root signature and converts them into MapT
-  LLVM_ABI bool ParseRootSignature(LLVMContext *Ctx,
-                                   mcdxbc::RootSignatureDesc &RSD);
+  LLVM_ABI llvm::Expected<llvm::mcdxbc::RootSignatureDesc>
+  ParseRootSignature(uint32_t Version);
 
 private:
-  bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                      MDNode *RootFlagNode);
-  bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                          MDNode *RootConstantNode);
-  bool parseRootDescriptors(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                            MDNode *RootDescriptorNode,
-                            RootSignatureElementKind ElementKind);
-  bool parseDescriptorRange(LLVMContext *Ctx, mcdxbc::DescriptorTable &Table,
-                            MDNode *RangeDescriptorNode);
-  bool parseDescriptorTable(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                            MDNode *DescriptorTableNode);
-  bool parseRootSignatureElement(LLVMContext *Ctx,
-                                 mcdxbc::RootSignatureDesc &RSD,
-                                 MDNode *Element);
-  bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-                          MDNode *StaticSamplerNode);
-
-  bool validateRootSignature(LLVMContext *Ctx,
-                             const llvm::mcdxbc::RootSignatureDesc &RSD);
+  llvm::Error parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+                             MDNode *RootFlagNode);
+  llvm::Error parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+                                 MDNode *RootConstantNode);
+  llvm::Error parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+                                   MDNode *RootDescriptorNode,
+                                   RootSignatureElementKind ElementKind);
+  llvm::Error parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+                                   MDNode *RangeDescriptorNode);
+  llvm::Error parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+                                   MDNode *DescriptorTableNode);
+  llvm::Error parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+                                        MDNode *Element);
+  llvm::Error parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+                                 MDNode *StaticSamplerNode);
+
+  llvm::Error validateRootSignature(const llvm::mcdxbc::RootSignatureDesc &RSD);
 
   MDNode *Root;
 };
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f59349ae029..41c23ecb692ea 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -13,7 +13,6 @@
 
 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
 #include "llvm/Frontend/HLSL/RootSignatureValidations.h"
-#include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -22,7 +21,12 @@ namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
+template <typename T> char RootSignatureValidationError<T>::ID;
+
+inline std::optional<uint32_t> extractMdIntValue(MDNode *Node,
                                                  unsigned int OpId) {
   if (auto *CI =
           mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpId).get()))
@@ -30,14 +34,14 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
   return std::nullopt;
 }
 
-static std::optional<float> extractMdFloatValue(MDNode *Node,
+inline std::optional<float> extractMdFloatValue(MDNode *Node,
                                                 unsigned int OpId) {
   if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
     return CI->getValueAPF().convertToFloat();
   return std::nullopt;
 }
 
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
+inline std::optional<StringRef> extractMdStringValue(MDNode *Node,
                                                      unsigned int OpId) {
   MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
   if (NodeText == nullptr)
@@ -45,19 +49,6 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
   return NodeText->getString();
 }
 
-static bool reportError(LLVMContext *Ctx, Twine Message,
-                        DiagnosticSeverity Severity = DS_Error) {
-  Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
-  return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
-                             uint32_t Value) {
-  Ctx->diagnose(DiagnosticInfoGeneric(
-      "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
-  return true;
-}
-
 static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
     {"CBV", dxil::ResourceClass::CBuffer},
     {"SRV", dxil::ResourceClass::SRV},
@@ -227,27 +218,23 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
   return MDNode::get(Ctx, Operands);
 }
 
-bool MetadataParser::parseRootFlags(LLVMContext *Ctx,
-                                    mcdxbc::RootSignatureDesc &RSD,
-                                    MDNode *RootFlagNode) {
-
+llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+                                           MDNode *RootFlagNode) {
   if (RootFlagNode->getNumOperands() != 2)
-    return reportError(Ctx, "Invalid format for RootFlag Element");
+    return make_error<InvalidRSMetadataFormat>("RootFlag Element");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
     RSD.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for RootFlag");
+    return make_error<InvalidRSMetadataValue>("RootFlag");
 
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
-                                        mcdxbc::RootSignatureDesc &RSD,
-                                        MDNode *RootConstantNode) {
-
+llvm::Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+                                               MDNode *RootConstantNode) {
   if (RootConstantNode->getNumOperands() != 5)
-    return reportError(Ctx, "Invalid format for RootConstants Element");
+    return make_error<InvalidRSMetadataFormat>("RootConstants Element");
 
   dxbc::RTS0::v1::RootParameterHeader Header;
   // The parameter offset doesn't matter here - we recalculate it during
@@ -258,39 +245,40 @@ bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
 
   dxbc::RTS0::v1::RootConstants Constants;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
     Constants.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return make_error<InvalidRSMetadataValue>("ShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
     Constants.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
     Constants.Num32BitValues = *Val;
   else
-    return reportError(Ctx, "Invalid value for Num32BitValues");
+    return make_error<InvalidRSMetadataValue>("Num32BitValues");
 
   RSD.ParametersContainer.addParameter(Header, Constants);
 
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseRootDescriptors(
-    LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
-    MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) {
+llvm::Error
+MetadataParser::parseRootDescriptors(mcdxbc::RootSignatureDesc &RSD,
+                                     MDNode *RootDescriptorNode,
+                                     RootSignatureElementKind ElementKind) {
   assert(ElementKind == RootSignatureElementKind::SRV ||
          ElementKind == RootSignatureElementKind::UAV ||
          ElementKind == RootSignatureElementKind::CBV &&
-             "parseRootDescriptors should only be called with RootDescriptor "
+             "parseRootDescriptors should only be called with RootDescriptor"
              "element kind.");
   if (RootDescriptorNode->getNumOperands() != 5)
-    return reportError(Ctx, "Invalid format for Root Descriptor Element");
+    return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
 
   dxbc::RTS0::v1::RootParameterHeader Header;
   switch (ElementKind) {
@@ -311,40 +299,38 @@ bool MetadataParser::parseRootDescriptors(
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
 
   dxbc::RTS0::v2::RootDescriptor Descriptor;
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
     Descriptor.ShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderRegister");
+    return make_error<InvalidRSMetadataValue>("ShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
     Descriptor.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (RSD.Version == 1) {
     RSD.ParametersContainer.addParameter(Header, Descriptor);
-    return false;
+    return llvm::Error::success();
   }
   assert(RSD.Version > 1);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
     Descriptor.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+    return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
 
   RSD.ParametersContainer.addParameter(Header, Descriptor);
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
-                                          mcdxbc::DescriptorTable &Table,
-                                          MDNode *RangeDescriptorNode) {
-
+llvm::Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+                                                 MDNode *RangeDescriptorNode) {
   if (RangeDescriptorNode->getNumOperands() != 6)
-    return reportError(Ctx, "Invalid format for Descriptor Range");
+    return make_error<InvalidRSMetadataFormat>("Descriptor Range");
 
   dxbc::RTS0::v2::DescriptorRange Range;
 
@@ -352,7 +338,7 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
       extractMdStringValue(RangeDescriptorNode, 0);
 
   if (!ElementText.has_value())
-    return reportError(Ctx, "Descriptor Range, first element is not a string.");
+    return make_error<InvalidRSMetadataFormat>("Descriptor Range");
 
   Range.RangeType =
       StringSwitch<uint32_t>(*ElementText)
@@ -364,50 +350,50 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
           .Default(~0U);
 
   if (Range.RangeType == ~0U)
-    return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+    return make_error<GenericRSMetadataError>("Invalid Descriptor Range type:" +
+                                              *ElementText);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
     Range.NumDescriptors = *Val;
   else
-    return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+    return make_error<GenericRSMetadataError>("Number of Descriptor in Range");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
     Range.BaseShaderRegister = *Val;
   else
-    return reportError(Ctx, "Invalid value for BaseShaderRegister");
+    return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
     Range.RegisterSpace = *Val;
   else
-    return reportError(Ctx, "Invalid value for RegisterSpace");
+    return make_error<InvalidRSMetadataValue>("RegisterSpace");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
     Range.OffsetInDescriptorsFromTableStart = *Val;
   else
-    return reportError(Ctx,
-                       "Invalid value for OffsetInDescriptorsFromTableStart");
+    return make_error<InvalidRSMetadataValue>(
+        "OffsetInDescriptorsFromTableStart");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
     Range.Flags = *Val;
   else
-    return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+    return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
 
   Table.Ranges.push_back(Range);
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
-                                          mcdxbc::RootSignatureDesc &RSD,
-                                          MDNode *DescriptorTableNode) {
+llvm::Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+                                                 MDNode *DescriptorTableNode) {
   const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
   if (NumOperands < 2)
-    return reportError(Ctx, "Invalid format for Descriptor Table");
+    return make_error<InvalidRSMetadataFormat>("Descriptor Table");
 
   dxbc::RTS0::v1::RootParameterHeader Header;
   if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
     Header.ShaderVisibility = *Val;
   else
-    return reportError(Ctx, "Invalid value for ShaderVisibility");
+    return make_error<InvalidRSMetadataValue>("ShaderVisibility");
 
   mcdxbc::DescriptorTable Table;
   Header.ParameterType =
@@ -416,98 +402,98 @@ bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
   for (unsigned int I = 2; I < NumOperands; I++) {
     MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
     if (Element == nullptr)
-      return reportError(Ctx, "Missing Root Element Metadata Node.");
+      return make_error<GenericRSMetadataError>(
+          "Missing Root Element Metadata Node.");
 
-    if (parseDescriptorRange(Ctx, Table, Element))
-      return true;
+    if (auto Err = parseDescriptorRange(Table, Element))
+      return Err;
   }
 
   RSD.ParametersContainer.addParameter(Header, Table);
-  return false;
+  return llvm::Error::success();
 }
 
-bool MetadataParser::parseStaticSampler(LLVMContext *Ctx,
-                                        mcdxbc::RootSignatureDesc &RSD,
-                                        MDNode *StaticSamplerNode) {
+llvm::Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+                                               MDNode *StaticSamplerNode) {
   if (StaticSamplerNode->getNumOperands() != 14)
-    return reportError(Ctx, "Invalid format for Static Sampler");
+    return make_error<InvalidRSMetadataFormat>("Static Sampler");
 
   dxbc::RTS0::v1::StaticSampler Sampler;
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
     Sampler.Filter = *Val;
   else
-    return reportError(Ctx, "Invalid value for Filter");
+    return make_error<InvalidRSMetadataValue>("Filter");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
     Sampler.AddressU = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressU");
+    return make_error<InvalidRSMetadataValue>("AddressU");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
     Sampler.AddressV = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressV");
+    return make_error<InvalidRSMetadataValue>("AddressV");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
     Sampler.AddressW = *Val;
   else
-    return reportError(Ctx, "Invalid value for AddressW");
+    return make_error<InvalidRSMetadataValue>("AddressW");
 
   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
     Sampler.MipLODBias = *Val;
   else
-    return reportError(Ctx, "Invalid value for MipLODBias");
+    return make_error<InvalidRSMetadataValue>("MipLODBias");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
     Sampler.MaxAnisotropy = *Val;
   else
-    return reportError(Ctx, "Invalid value for MaxAnisotropy");
+    return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
     Sampler.ComparisonFunc = *Val;
   else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+    return make_error<InvalidRSMetadataValue>("ComparisonFunc");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
     Sampler.BorderColor = *Val;
   else
-    return reportError(Ctx, "Invalid value for ComparisonFunc ");
+    return make_error<InvalidRSMetadataValue>("ComparisonFunc");
 
   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
     Sampler.MinLOD = *Val;
   else
-    return reportError(Ctx, "Invalid value for MinLOD");
+    return make_error<InvalidRSMetadataValue>("MinLOD");
 
   if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
     Sampler.MaxLOD = *Val;
   else
-    return reportError(Ctx, "Invalid value for MaxLOD");
+    return make_error<InvalidRSMetadataValue>("MaxLOD");
 
   if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
     Sa...
[truncated]

@@ -3,7 +3,7 @@

target triple = "dxil-unknown-shadermodel6.0-compute"

; CHECK: error: Invalid value for MaxLOD: 0
; CHECK: error: Invalid value for MaxLOD: nan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would this change to nan?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was always supposed to be nan, previously the error message converted nan into a string. Now, with the error classes, I can use errs() output stream to format those values when showing errors to the user. So, this changed occurred due to the change into custom Error Classes

…ndling/improve-error-handling-in-rs-metadata-parser
@joaosaffran joaosaffran requested a review from inbelic July 22, 2025 17:00
@joaosaffran joaosaffran changed the base branch from users/joaosaffran/149221 to main July 24, 2025 00:03
Copy link

github-actions bot commented Jul 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some high level feedback. In general we should try to avoid taking ownership and copying around strings. Since most common uses of strings in LLVM are either static constants or associated with data owned by the module you can usually use StringRef over std::string.

Overall I think this is the right direction.

mcdxbc::RootSignatureDesc &RSD,
MDNode *RootFlagNode) {

llvm::Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this source file is part of LLVM, we should just have a using namespace llvm declaration to avoid needing to explicitly qualify all the LLVM types and methods.

That also makes it more consistent since you have functions like this that are in types declared within the llvm namespace so you can inconsistently use unqualified or qualified names (see below where you call make_error but then llvm::Error::success()).

Comment on lines 123 to 124
if (auto Err = RSDOrErr.takeError()) {
reportError(Ctx, toString(std::move(Err)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general with errors especially if they can be more than one error you don't just want to call toString, you should handle each error.

I think this is about right, but you'll need to try it:

Suggested change
if (auto Err = RSDOrErr.takeError()) {
reportError(Ctx, toString(std::move(Err)));
if (!RSDOrErr) {
handleAllErrors(RSDOrErr.takeError(), [&](ErrorInfoBase &EIB) {
Ctx->emitError(EI.message());
});

@@ -2,7 +2,7 @@

target triple = "dxil-unknown-shadermodel6.0-compute"

; CHECK: error: Invalid Descriptor Range type: Invalid
; CHECK: error: Invalid Descriptor Range type:Invalid
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a meaningfully worse printing.

@joaosaffran joaosaffran requested a review from llvm-beanz July 25, 2025 21:59
Comment on lines 711 to 713
if (auto Err = parseRootSignatureElement(RSD, Element)) {
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (auto Err = parseRootSignatureElement(RSD, Element)) {
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
if (auto Err = parseRootSignatureElement(RSD, Element))
DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));

@joaosaffran joaosaffran requested a review from llvm-beanz July 29, 2025 17:53
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple small nits, but otherwise I think this is fine.

@@ -14,6 +14,8 @@
#ifndef LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H
#define LLVM_FRONTEND_HLSL_ROOTSIGNATUREMETADATA_H

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include "llvm/ADT/Twine.h"

assert(ElementKind == RootSignatureElementKind::SRV ||
ElementKind == RootSignatureElementKind::UAV ||
ElementKind == RootSignatureElementKind::CBV &&
"parseRootDescriptors should only be called with RootDescriptor "
"parseRootDescriptors should only be called with RootDescriptor"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String concat in tokenization does not add spaces, so if this ever gets printed because the assert gets hit this will look oddly garbled without the trailing space.

Suggested change
"parseRootDescriptors should only be called with RootDescriptor"
"parseRootDescriptors should only be called with RootDescriptor "

@@ -12,16 +12,23 @@
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/ADT/Twine.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe this is used anywhere.

Suggested change
#include "llvm/ADT/Twine.h"

Joao Saffran added 2 commits July 30, 2025 11:27
Comment on lines +26 to +29
char GenericRSMetadataError::ID;
char InvalidRSMetadataFormat::ID;
char InvalidRSMetadataValue::ID;
template <typename T> char RootSignatureValidationError<T>::ID;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not immediately clear to me why these are required?

Maybe it is part of a pattern I haven't seen before. But just wanted to check

@joaosaffran joaosaffran merged commit b2e5303 into llvm:main Aug 1, 2025
10 checks passed
krishna2803 pushed a commit to krishna2803/llvm-project that referenced this pull request Aug 12, 2025
…lvm#149232)

This PR addresses
llvm#144465 (comment).
Using `joinErrors` and `llvm:Error` instead of boolean values.

---------

Co-authored-by: joaosaffran <[email protected]>
Co-authored-by: Joao Saffran <{ID}+{username}@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants