Skip to content

Conversation

joaosaffran
Copy link
Contributor

We have too many custom error classes that look too much alike when error handling root signature metadata parser. This PR removes those custom error classes and instead reuses StringError and a few pre-formatted error messages.

Closes: #159429

@llvmbot llvmbot added the HLSL HLSL Language Support label Oct 3, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-hlsl

Author: None (joaosaffran)

Changes

We have too many custom error classes that look too much alike when error handling root signature metadata parser. This PR removes those custom error classes and instead reuses StringError and a few pre-formatted error messages.

Closes: #159429


Patch is 36.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161921.diff

2 Files Affected:

  • (modified) llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h (+104-156)
  • (modified) llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp (+203-111)
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
index bfcbf728d415c..76d9e93e2470a 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h
@@ -28,164 +28,112 @@ class Metadata;
 namespace hlsl {
 namespace rootsig {
 
-template <typename T>
-class RootSignatureValidationError
-    : public ErrorInfo<RootSignatureValidationError<T>> {
-public:
-  static char ID;
-  StringRef ParamName;
-  T Value;
-
-  RootSignatureValidationError(StringRef ParamName, T Value)
-      : ParamName(ParamName), Value(Value) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Invalid value for " << ParamName << ": " << Value;
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
-
-class OffsetAppendAfterOverflow : public ErrorInfo<OffsetAppendAfterOverflow> {
-public:
-  static char ID;
-  dxil::ResourceClass Type;
-  uint32_t Register;
-  uint32_t Space;
-
-  OffsetAppendAfterOverflow(dxil::ResourceClass Type, uint32_t Register,
-                            uint32_t Space)
-      : Type(Type), Register(Register), Space(Space) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Range " << getResourceClassName(Type) << "(register=" << Register
-       << ", space=" << Space << ") "
-       << "cannot be appended after an unbounded range ";
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
-
-class ShaderRegisterOverflowError
-    : public ErrorInfo<ShaderRegisterOverflowError> {
-public:
-  static char ID;
-  dxil::ResourceClass Type;
-  uint32_t Register;
-  uint32_t Space;
-
-  ShaderRegisterOverflowError(dxil::ResourceClass Type, uint32_t Register,
-                              uint32_t Space)
-      : Type(Type), Register(Register), Space(Space) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Overflow for shader register range: " << getResourceClassName(Type)
-       << "(register=" << Register << ", space=" << Space << ").";
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
-
-class OffsetOverflowError : public ErrorInfo<OffsetOverflowError> {
-public:
-  static char ID;
-  dxil::ResourceClass Type;
-  uint32_t Register;
-  uint32_t Space;
-
-  OffsetOverflowError(dxil::ResourceClass Type, uint32_t Register,
-                      uint32_t Space)
-      : Type(Type), Register(Register), Space(Space) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Offset overflow for descriptor range: " << getResourceClassName(Type)
-       << "(register=" << Register << ", space=" << Space << ").";
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
+enum class RSErrorKind {
+  Validation,
+  AppendAfterUnboundedRange,
+  ShaderRegisterOverflow,
+  OffsetOverflow,
+  SamplerMixin,
+  GenericMetadata,
+  InvalidMetadataFormat,
+  InvalidMetadataValue
 };
 
-class TableSamplerMixinError : public ErrorInfo<TableSamplerMixinError> {
-public:
-  static char ID;
-  dxil::ResourceClass Type;
-  uint32_t Location;
-
-  TableSamplerMixinError(dxil::ResourceClass Type, uint32_t Location)
-      : Type(Type), Location(Location) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Samplers cannot be mixed with other "
-       << "resource types in a descriptor table, " << getResourceClassName(Type)
-       << "(location=" << Location << ")";
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
-
-class GenericRSMetadataError : public ErrorInfo<GenericRSMetadataError> {
-public:
-  LLVM_ABI static char ID;
-  StringRef Message;
-  MDNode *MD;
-
-  GenericRSMetadataError(StringRef Message, MDNode *MD)
-      : Message(Message), MD(MD) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << Message;
-    if (MD) {
-      OS << "\n";
-      MD->printTree(OS);
-    }
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
-
-class InvalidRSMetadataFormat : public ErrorInfo<InvalidRSMetadataFormat> {
-public:
-  LLVM_ABI static char ID;
-  StringRef ElementName;
-
-  InvalidRSMetadataFormat(StringRef ElementName) : ElementName(ElementName) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Invalid format for  " << ElementName;
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
-
-class InvalidRSMetadataValue : public ErrorInfo<InvalidRSMetadataValue> {
-public:
-  LLVM_ABI static char ID;
-  StringRef ParamName;
-
-  InvalidRSMetadataValue(StringRef ParamName) : ParamName(ParamName) {}
-
-  void log(raw_ostream &OS) const override {
-    OS << "Invalid value for " << ParamName;
-  }
-
-  std::error_code convertToErrorCode() const override {
-    return llvm::inconvertibleErrorCode();
-  }
-};
+template <typename T>
+void formatImpl(raw_string_ostream &Buff,
+                std::integral_constant<RSErrorKind, RSErrorKind::Validation>,
+                StringRef ParamName, T Value);
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::AppendAfterUnboundedRange>,
+    dxil::ResourceClass Type, uint32_t Register, uint32_t Space);
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::ShaderRegisterOverflow>,
+    dxil::ResourceClass Type, uint32_t Register, uint32_t Space);
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::OffsetOverflow>,
+    dxil::ResourceClass Type, uint32_t Register, uint32_t Space);
+
+void formatImpl(raw_string_ostream &Buff,
+                std::integral_constant<RSErrorKind, RSErrorKind::SamplerMixin>,
+                dxil::ResourceClass Type, uint32_t Location);
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::InvalidMetadataFormat>,
+    StringRef ElementName);
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::InvalidMetadataValue>,
+    StringRef ParamName);
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::GenericMetadata>,
+    StringRef Message, MDNode *MD);
+
+template <typename... ArgsTs>
+inline void formatImpl(raw_string_ostream &Buff, RSErrorKind Kind,
+                       ArgsTs... Args) {
+  switch (Kind) {
+  case RSErrorKind::Validation:
+    return formatImpl(
+        Buff, std::integral_constant<RSErrorKind, RSErrorKind::Validation>(),
+        Args...);
+  case RSErrorKind::AppendAfterUnboundedRange:
+    return formatImpl(
+        Buff,
+        std::integral_constant<RSErrorKind,
+                               RSErrorKind::AppendAfterUnboundedRange>(),
+        Args...);
+  case RSErrorKind::ShaderRegisterOverflow:
+    return formatImpl(
+        Buff,
+        std::integral_constant<RSErrorKind,
+                               RSErrorKind::ShaderRegisterOverflow>(),
+        Args...);
+  case RSErrorKind::OffsetOverflow:
+    return formatImpl(
+        Buff,
+        std::integral_constant<RSErrorKind, RSErrorKind::OffsetOverflow>(),
+        Args...);
+  case RSErrorKind::SamplerMixin:
+    return formatImpl(
+        Buff, std::integral_constant<RSErrorKind, RSErrorKind::SamplerMixin>(),
+        Args...);
+  case RSErrorKind::InvalidMetadataFormat:
+    return formatImpl(
+        Buff,
+        std::integral_constant<RSErrorKind,
+                               RSErrorKind::InvalidMetadataFormat>(),
+        Args...);
+  case RSErrorKind::InvalidMetadataValue:
+    return formatImpl(
+        Buff,
+        std::integral_constant<RSErrorKind,
+                               RSErrorKind::InvalidMetadataValue>(),
+        Args...);
+  case RSErrorKind::GenericMetadata:
+    return formatImpl(
+        Buff,
+        std::integral_constant<RSErrorKind, RSErrorKind::GenericMetadata>(),
+        Args...);
+  }
+}
+
+template <typename... ArgsTs>
+static llvm::Error createRSError(RSErrorKind Kind, ArgsTs... Args) {
+  std::string Msg;
+  raw_string_ostream Buff(Msg);
+  formatImpl(Buff, Kind, Args...);
+  return createStringError(std::move(Buff.str()), inconvertibleErrorCode());
+}
 
 class MetadataBuilder {
 public:
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 7a0cf408968de..7ed4ec0dc3457 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -24,15 +24,75 @@ namespace llvm {
 namespace hlsl {
 namespace rootsig {
 
-char GenericRSMetadataError::ID;
-char InvalidRSMetadataFormat::ID;
-char InvalidRSMetadataValue::ID;
-char TableSamplerMixinError::ID;
-char ShaderRegisterOverflowError::ID;
-char OffsetOverflowError::ID;
-char OffsetAppendAfterOverflow::ID;
+template <typename T>
+void formatImpl(raw_string_ostream &Buff,
+                std::integral_constant<RSErrorKind, RSErrorKind::Validation>,
+                StringRef ParamName, T Value) {
+  Buff << "Invalid value for: " << ParamName << ":";
+  if constexpr (std::is_same_v<std::decay_t<T>, std::nullptr_t>) {
+    Buff << "nullptr";
+  } else {
+    Buff << Value;
+  }
+}
 
-template <typename T> char RootSignatureValidationError<T>::ID;
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::AppendAfterUnboundedRange>,
+    dxil::ResourceClass Type, uint32_t Register, uint32_t Space) {
+  Buff << "Range " << getResourceClassName(Type) << "(register=" << Register
+       << ", space=" << Space << ") "
+       << "cannot be appended after an unbounded range ";
+}
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::ShaderRegisterOverflow>,
+    dxil::ResourceClass Type, uint32_t Register, uint32_t Space) {
+  Buff << "Overflow for shader register range: " << getResourceClassName(Type)
+       << "(register=" << Register << ", space=" << Space << ").";
+}
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::OffsetOverflow>,
+    dxil::ResourceClass Type, uint32_t Register, uint32_t Space) {
+  Buff << "Offset overflow for descriptor range: " << getResourceClassName(Type)
+       << "(register=" << Register << ", space=" << Space << ").";
+}
+
+void formatImpl(raw_string_ostream &Buff,
+                std::integral_constant<RSErrorKind, RSErrorKind::SamplerMixin>,
+                dxil::ResourceClass Type, uint32_t Location) {
+  Buff << "Samplers cannot be mixed with other "
+       << "resource types in a descriptor table, " << getResourceClassName(Type)
+       << "(location=" << Location << ")";
+}
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::InvalidMetadataFormat>,
+    StringRef ElementName) {
+  Buff << "Invalid format for  " << ElementName;
+}
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::InvalidMetadataValue>,
+    StringRef ParamName) {
+  Buff << "Invalid value for " << ParamName;
+}
+
+void formatImpl(
+    raw_string_ostream &Buff,
+    std::integral_constant<RSErrorKind, RSErrorKind::GenericMetadata>,
+    StringRef Message, MDNode *MD) {
+  Buff << Message;
+  if (MD) {
+    Buff << "\n";
+    MD->printTree(Buff);
+  }
+}
 
 static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
                                                  unsigned int OpId) {
@@ -65,10 +125,11 @@ extractEnumValue(MDNode *Node, unsigned int OpId, StringRef ErrText,
                  llvm::function_ref<bool(uint32_t)> VerifyFn) {
   if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
     if (!VerifyFn(*Val))
-      return make_error<RootSignatureValidationError<uint32_t>>(ErrText, *Val);
+      return createRSError(RSErrorKind::Validation, ErrText, *Val);
     return static_cast<T>(*Val);
   }
-  return make_error<InvalidRSMetadataValue>("ShaderVisibility");
+  return createRSError(RSErrorKind::InvalidMetadataValue,
+                       StringRef("ShaderVisibility"));
 }
 
 namespace {
@@ -226,12 +287,14 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
 Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
                                      MDNode *RootFlagNode) {
   if (RootFlagNode->getNumOperands() != 2)
-    return make_error<InvalidRSMetadataFormat>("RootFlag Element");
+    return createRSError(RSErrorKind::InvalidMetadataFormat,
+                         StringRef("RootFlags Element"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
     RSD.Flags = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("RootFlag");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("RootFlag"));
 
   return Error::success();
 }
@@ -239,7 +302,8 @@ Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
 Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
                                          MDNode *RootConstantNode) {
   if (RootConstantNode->getNumOperands() != 5)
-    return make_error<InvalidRSMetadataFormat>("RootConstants Element");
+    return createRSError(RSErrorKind::InvalidMetadataFormat,
+                         StringRef("RootConstants Element"));
 
   Expected<dxbc::ShaderVisibility> Visibility =
       extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1,
@@ -252,17 +316,20 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
     Constants.ShaderRegister = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("ShaderRegister");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("ShaderRegister"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
     Constants.RegisterSpace = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("RegisterSpace");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("RegisterSpace"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
     Constants.Num32BitValues = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("Num32BitValues");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("Num32BitValues"));
 
   RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
                                        *Visibility, Constants);
@@ -279,7 +346,8 @@ Error MetadataParser::parseRootDescriptors(
          "parseRootDescriptors should only be called with RootDescriptor "
          "element kind.");
   if (RootDescriptorNode->getNumOperands() != 5)
-    return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
+    return createRSError(RSErrorKind::InvalidMetadataFormat,
+                         StringRef("Root Descriptor Element"));
 
   dxbc::RootParameterType Type;
   switch (ElementKind) {
@@ -308,12 +376,14 @@ Error MetadataParser::parseRootDescriptors(
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
     Descriptor.ShaderRegister = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("ShaderRegister");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("ShaderRegister"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
     Descriptor.RegisterSpace = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("RegisterSpace");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("RegisterSpace"));
 
   if (RSD.Version == 1) {
     RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
@@ -324,7 +394,8 @@ Error MetadataParser::parseRootDescriptors(
   if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
     Descriptor.Flags = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("Root Descriptor Flags"));
 
   RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
   return Error::success();
@@ -333,7 +404,8 @@ Error MetadataParser::parseRootDescriptors(
 Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
                                            MDNode *RangeDescriptorNode) {
   if (RangeDescriptorNode->getNumOperands() != 6)
-    return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+    return createRSError(RSErrorKind::InvalidMetadataFormat,
+                         StringRef("Descriptor Range"));
 
   mcdxbc::DescriptorRange Range;
 
@@ -341,7 +413,8 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
       extractMdStringValue(RangeDescriptorNode, 0);
 
   if (!ElementText.has_value())
-    return make_error<InvalidRSMetadataFormat>("Descriptor Range");
+    return createRSError(RSErrorKind::InvalidMetadataFormat,
+                         StringRef("Descriptor Range"));
 
   if (*ElementText == "CBV")
     Range.RangeType = dxil::ResourceClass::CBuffer;
@@ -352,35 +425,40 @@ Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
   else if (*ElementText == "Sampler")
     Range.RangeType = dxil::ResourceClass::Sampler;
   else
-    return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.",
-                                              RangeDescriptorNode);
+    return createRSError(RSErrorKind::GenericMetadata,
+                         StringRef("Invalid Descriptor Range type."),
+                         RangeDescriptorNode);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
     Range.NumDescriptors = *Val;
   else
-    return make_error<GenericRSMetadataError>("Number of Descriptor in Range",
-                                              RangeDescriptorNode);
+    return createRSError(RSErrorKind::GenericMetadata,
+                         StringRef("Number of Descriptor in Range"),
+                         RangeDescriptorNode);
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
     Range.BaseShaderRegister = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("BaseShaderRegister"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
     Range.RegisterSpace = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("RegisterSpace");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("RegisterSpace"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
     Range.OffsetInDescriptorsFromTableStart = *Val;
   else
-    return make_error<InvalidRSMetadataValue>(
-        "OffsetInDescriptorsFromTableStart");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("OffsetInDescriptorsFromTableStart"));
 
   if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
     Range.Flags = *Val;
   else
-    return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
+    return createRSError(RSErrorKind::InvalidMetadataValue,
+                         StringRef("Descriptor Range Flags"));
 
   Table.Ranges.push_back(Range);
   return Error::success();
@@ -390,7 +468,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
                                            MDNode *DescriptorTableNode) {
   const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
   if (NumOperands < 2)
-    return make_error<InvalidRSMetadataFormat>("Descriptor Table");
+    return createRSError(RSErrorKind::InvalidMetadataFormat,
+                     ...
[truncated]

Copy link

github-actions bot commented Oct 6, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions h,cpp -- llvm/include/llvm/Frontend/HLSL/RootSignatureMetadata.h llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index d58a17cd5..a65651871 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -823,10 +823,10 @@ Error MetadataParser::validateRootSignature(
                                               Sampler.ShaderRegister));
 
     if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
-      DeferredErrs =
-          joinErrors(std::move(DeferredErrs),
-                     createRSError(RSErrorKind::Validation, StringRef(
-                         "RegisterSpace"), Sampler.RegisterSpace));
+      DeferredErrs = joinErrors(std::move(DeferredErrs),
+                                createRSError(RSErrorKind::Validation,
+                                              StringRef("RegisterSpace"),
+                                              Sampler.RegisterSpace));
     bool IsValidFlag =
         dxbc::isValidStaticSamplerFlags(Sampler.Flags) &&
         hlsl::rootsig::verifyStaticSamplerFlags(

Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that replacing all of the error classes with enum values makes this any simpler. I was thinking that we could probably just have one error class, like:

class RootSignatureValidationError
    : public ErrorInfo<RootSignatureValidationError> {
public:
  static char ID;
  std::string Msg;

  RootSignatureValidationError(const Twine &Msg) : Msg(Msg.str()) {}

  void log(raw_ostream &OS) const override { OS << Msg; }

  std::error_code convertToErrorCode() const override {
    return llvm::inconvertibleErrorCode();
  }
};

This could just be StringError, but I think having one class here does make things clearer.

Then, instead of the error class heirarchy being responsible for all of the error messages, we could just use formatv when it isn't a fixed string:

static Error makeRSError(const Twine &Msg) {
  return make_error<RootSignatureValidationError>(Msg);
}
// ...
makeRSError("Invalid format for RootFlag Element");
makeRSError("Invalid value for RootFlag");
makeRSError(
    formatv("Samplers cannot be mixed with other resource "
            "types in a descriptor table, {0}(location={1})",
            getResourceClassName(CurrRC), Location));
makeRSError(formatv("Invalid value for Version: {0}", RSD.Version))

This might make the versions that print more structured data a bit harder to follow, but you can work around that with a local format adaptor object:

namespace {
struct FmtRange {
  dxil::ResourceClass Type;
  uint32_t Register;
  uint32_t Space;

  FmtRange(const mcdxbc::DescriptorRange &Range)
           : Type(Range.RangeType), Register(Range.BaseShaderRegister),
             Space(Range.RegisterSpace) {}
};
raw_ostream &operator<<(llvm::raw_ostream &OS, const FmtRange &Range) {
  OS << getResourceClassName(Range.Type) << "(register=" << Range.Register
     << ", space=" << Range.Space << ")";
  return OS;
}
}
// ...
makeRSError(formatv("Overflow for shader register range: {0}",
                    FmtRange(Range)));

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[DirectX] Simplify errors from RootSignatureMetadata.h
3 participants