Skip to content

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Jul 7, 2025

This pr fixes a bug that allows parameters to be specified without an intermediate comma.

After this pr, we will correctly produce a diagnostic for (eg):

RootFlags(0) CBV(b0)

This pr updates the problematic code pattern containing a chain of 'if' statements to a chain of 'else if' statements, to prevent parsing of an element before checking for a comma.

This pr also does 2 small updates, while in the region:

  1. Simplify the do loop that these if statements are contained in. This helps code readability and makes it easier to improve the diagnostics further
  2. Moves the consumeExpectedToken function calls to be right after the parse.*Params invocation. This will ensure that the comma or invalid token error is presented before a "missed mandatory param" diagnostic.
  • Updates all occurrences of the if chains with an else-if chain
  • Simplifies the surrounding do loop to be an easier to understand while loop
  • Moves the consumeExpectedToken diagnostic right after the loop so that the missing comma diagnostic is produce before checking for any missed mandatory arguments
  • Adds unit tests for this scenario
  • Small fix to the diagnostic of RootDescriptors to use their respective Token instead of RootConstants

Resolves: #147337

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Jul 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes

TODO: will fill in.

Resolves: #147337


Patch is 47.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147350.diff

5 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1-2)
  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+10-4)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+198-289)
  • (modified) clang/test/SemaHLSL/RootSignature-err.hlsl (+13-5)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+164-2)
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 6c30da376dafb..7c636d9049762 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1857,8 +1857,7 @@ def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
 // HLSL Root Signature Parser Diagnostics
-def err_hlsl_unexpected_end_of_params
-    : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_invalid_param : Error<"invalid parameter of %0">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 def err_hlsl_number_literal_overflow : Error<
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 9ef5b64d7b4a5..2846b83ee959b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -100,7 +100,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
   };
   std::optional<ParsedRootDescriptorParams>
-  parseRootDescriptorParams(RootSignatureToken::Kind RegType);
+  parseRootDescriptorParams(RootSignatureToken::Kind DescType,
+                            RootSignatureToken::Kind RegType);
 
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -110,7 +111,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
-  parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
+  parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
+                                   RootSignatureToken::Kind RegType);
 
   struct ParsedStaticSamplerParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -178,8 +180,8 @@ class RootSignatureParser {
   ///
   /// Returns true if there was an error reported.
   bool consumeExpectedToken(
-      RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
-      RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+      RootSignatureToken::Kind Expected,
+      std::optional<RootSignatureToken::Kind> Context = std::nullopt);
 
   /// Peek if the next token is of the expected kind and if it is then consume
   /// it.
@@ -195,6 +197,10 @@ class RootSignatureParser {
   /// StringLiterals
   SourceLocation getTokenLocation(RootSignatureToken Tok);
 
+  DiagnosticBuilder reportDiag(unsigned DiagID) {
+    return getDiags().Report(getTokenLocation(CurToken), DiagID);
+  }
+
 private:
   llvm::dxbc::RootSignatureVersion Version;
   SmallVector<RootSignatureElement> &Elements;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 86dd01c1a2841..582445b876c3f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,58 +25,62 @@ RootSignatureParser::RootSignatureParser(
       Lexer(Signature->getString()), PP(PP), CurToken(0) {}
 
 bool RootSignatureParser::parse() {
-  // Iterate as many RootSignatureElements as possible
-  do {
+  // Iterate as many RootSignatureElements as possible, until we hit the
+  // end of the stream
+  while (!peekExpectedToken(TokenKind::end_of_stream)) {
     std::optional<RootSignatureElement> Element = std::nullopt;
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      // RootFlags
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Flags = parseRootFlags();
       if (!Flags.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Flags);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+      // RootConstants
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Constants = parseRootConstants();
       if (!Constants.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Constants);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+      // DescriptorTable
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Table = parseDescriptorTable();
       if (!Table.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Table);
-    }
-
-    if (tryConsumeExpectedToken(
-            {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+    } else if (tryConsumeExpectedToken(
+                   {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+      // RootDescriptor - CBV, SRV, UAV
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Descriptor = parseRootDescriptor();
       if (!Descriptor.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Descriptor);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+      // StaticSampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Sampler = parseStaticSampler();
       if (!Sampler.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Sampler);
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootSignature;
+      return true;
     }
 
     if (Element.has_value())
       Elements.push_back(*Element);
 
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+    // ',' denotes another element, otherwise, expected to be at end of stream
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
-  return consumeExpectedToken(TokenKind::end_of_stream,
-                              diag::err_hlsl_unexpected_end_of_params,
-                              /*param of=*/TokenKind::kw_RootSignature);
+  return consumeExpectedToken(TokenKind::end_of_stream);
 }
 
 template <typename FlagType>
@@ -92,8 +96,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
@@ -101,8 +104,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   // Handle the edge-case of '0' to specify no flags set
   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
     if (!verifyZeroFlag()) {
-      getDiags().Report(getTokenLocation(CurToken),
-                        diag::err_hlsl_rootsig_non_zero_flag);
+      reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
       return std::nullopt;
     }
   } else {
@@ -128,9 +130,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
     } while (tryConsumeExpectedToken(TokenKind::pu_or));
   }
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootFlags))
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
     return std::nullopt;
 
   return Flags;
@@ -140,8 +140,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootConstants Constants;
@@ -150,10 +149,12 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters where provided
   if (!Params->Num32BitConstants.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
+    reportDiag(diag::err_hlsl_rootsig_missing_param)
         << TokenKind::kw_num32BitConstants;
     return std::nullopt;
   }
@@ -161,9 +162,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   Constants.Num32BitConstants = Params->Num32BitConstants.value();
 
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::bReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
     return std::nullopt;
   }
 
@@ -176,11 +175,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (Params->Space.has_value())
     Constants.Space = Params->Space.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Constants;
 }
 
@@ -192,8 +186,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
 
   TokenKind DescriptorKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootDescriptor Descriptor;
@@ -216,15 +209,16 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   }
   Descriptor.setDefaultFlags(Version);
 
-  auto Params = parseRootDescriptorParams(ExpectedReg);
+  auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -240,11 +234,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   if (Params->Flags.has_value())
     Descriptor.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Descriptor;
 }
 
@@ -252,30 +241,27 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
   assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTable Table;
   std::optional<llvm::dxbc::ShaderVisibility> Visibility;
 
-  // Iterate as many Clauses as possible
-  do {
+  // Iterate as many Clauses as possible, until we hit ')'
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+      // DescriptorTableClause - CBV, SRV, UAV, or Sampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Clause = parseDescriptorTableClause();
       if (!Clause.has_value())
         return std::nullopt;
       Elements.push_back(RootSignatureElement(ElementLoc, *Clause));
       Table.NumClauses++;
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // visibility = SHADER_VISIBILITY
       if (Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -285,18 +271,25 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
       Visibility = parseShaderVisibility();
       if (!Visibility.has_value())
         return std::nullopt;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_DescriptorTable;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
 
   // Fill in optional visibility
   if (Visibility.has_value())
     Table.Visibility = Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_DescriptorTable))
-    return std::nullopt;
-
   return Table;
 }
 
@@ -310,8 +303,7 @@ RootSignatureParser::parseDescriptorTableClause() {
 
   TokenKind ParamKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTableClause Clause;
@@ -338,15 +330,16 @@ RootSignatureParser::parseDescriptorTableClause() {
   }
   Clause.setDefaultFlags(Version);
 
-  auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+  auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -365,11 +358,6 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Flags.has_value())
     Clause.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/ParamKind))
-    return std::nullopt;
-
   return Clause;
 }
 
@@ -377,8 +365,7 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   StaticSampler Sampler;
@@ -387,11 +374,12 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::sReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
     return std::nullopt;
   }
 
@@ -434,11 +422,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (Params->Visibility.has_value())
     Sampler.Visibility = Params->Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_StaticSampler))
-    return std::nullopt;
-
   return Sampler;
 }
 
@@ -451,13 +434,11 @@ RootSignatureParser::parseRootConstantParams() {
          "Expects to only be invoked starting at given token");
 
   ParsedConstantParams Params;
-  do {
-    // `num32BitConstants` `=` POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      // `num32BitConstants` `=` POS_INT
       if (Params.Num32BitConstants.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -468,28 +449,20 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Num32BitConstants.has_value())
         return std::nullopt;
       Params.Num32BitConstants = Num32BitConstants;
-    }
-
-    // `b` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+    } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      // `b` POS_INT
       if (Params.Reg.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
       auto Reg = parseRegister();
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -500,14 +473,10 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -518,39 +487,43 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Visibility.has_value())
         return std::nullopt;
       Params.Visibility = Visibility;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootConstants;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
 
 std::optional<RootSignatureParser::ParsedRootDescriptorParams>
-RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
+RootSignatureParser::parseRootDescriptorParams(To...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes

TODO: will fill in.

Resolves: #147337


Patch is 47.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147350.diff

5 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1-2)
  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+10-4)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+198-289)
  • (modified) clang/test/SemaHLSL/RootSignature-err.hlsl (+13-5)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+164-2)
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 6c30da376dafb..7c636d9049762 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1857,8 +1857,7 @@ def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
 // HLSL Root Signature Parser Diagnostics
-def err_hlsl_unexpected_end_of_params
-    : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_invalid_param : Error<"invalid parameter of %0">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 def err_hlsl_number_literal_overflow : Error<
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 9ef5b64d7b4a5..2846b83ee959b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -100,7 +100,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
   };
   std::optional<ParsedRootDescriptorParams>
-  parseRootDescriptorParams(RootSignatureToken::Kind RegType);
+  parseRootDescriptorParams(RootSignatureToken::Kind DescType,
+                            RootSignatureToken::Kind RegType);
 
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -110,7 +111,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
-  parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
+  parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
+                                   RootSignatureToken::Kind RegType);
 
   struct ParsedStaticSamplerParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -178,8 +180,8 @@ class RootSignatureParser {
   ///
   /// Returns true if there was an error reported.
   bool consumeExpectedToken(
-      RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
-      RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+      RootSignatureToken::Kind Expected,
+      std::optional<RootSignatureToken::Kind> Context = std::nullopt);
 
   /// Peek if the next token is of the expected kind and if it is then consume
   /// it.
@@ -195,6 +197,10 @@ class RootSignatureParser {
   /// StringLiterals
   SourceLocation getTokenLocation(RootSignatureToken Tok);
 
+  DiagnosticBuilder reportDiag(unsigned DiagID) {
+    return getDiags().Report(getTokenLocation(CurToken), DiagID);
+  }
+
 private:
   llvm::dxbc::RootSignatureVersion Version;
   SmallVector<RootSignatureElement> &Elements;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 86dd01c1a2841..582445b876c3f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,58 +25,62 @@ RootSignatureParser::RootSignatureParser(
       Lexer(Signature->getString()), PP(PP), CurToken(0) {}
 
 bool RootSignatureParser::parse() {
-  // Iterate as many RootSignatureElements as possible
-  do {
+  // Iterate as many RootSignatureElements as possible, until we hit the
+  // end of the stream
+  while (!peekExpectedToken(TokenKind::end_of_stream)) {
     std::optional<RootSignatureElement> Element = std::nullopt;
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      // RootFlags
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Flags = parseRootFlags();
       if (!Flags.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Flags);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+      // RootConstants
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Constants = parseRootConstants();
       if (!Constants.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Constants);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+      // DescriptorTable
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Table = parseDescriptorTable();
       if (!Table.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Table);
-    }
-
-    if (tryConsumeExpectedToken(
-            {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+    } else if (tryConsumeExpectedToken(
+                   {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+      // RootDescriptor - CBV, SRV, UAV
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Descriptor = parseRootDescriptor();
       if (!Descriptor.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Descriptor);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+      // StaticSampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Sampler = parseStaticSampler();
       if (!Sampler.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Sampler);
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootSignature;
+      return true;
     }
 
     if (Element.has_value())
       Elements.push_back(*Element);
 
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+    // ',' denotes another element, otherwise, expected to be at end of stream
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
-  return consumeExpectedToken(TokenKind::end_of_stream,
-                              diag::err_hlsl_unexpected_end_of_params,
-                              /*param of=*/TokenKind::kw_RootSignature);
+  return consumeExpectedToken(TokenKind::end_of_stream);
 }
 
 template <typename FlagType>
@@ -92,8 +96,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
@@ -101,8 +104,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   // Handle the edge-case of '0' to specify no flags set
   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
     if (!verifyZeroFlag()) {
-      getDiags().Report(getTokenLocation(CurToken),
-                        diag::err_hlsl_rootsig_non_zero_flag);
+      reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
       return std::nullopt;
     }
   } else {
@@ -128,9 +130,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
     } while (tryConsumeExpectedToken(TokenKind::pu_or));
   }
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootFlags))
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
     return std::nullopt;
 
   return Flags;
@@ -140,8 +140,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootConstants Constants;
@@ -150,10 +149,12 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters where provided
   if (!Params->Num32BitConstants.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
+    reportDiag(diag::err_hlsl_rootsig_missing_param)
         << TokenKind::kw_num32BitConstants;
     return std::nullopt;
   }
@@ -161,9 +162,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   Constants.Num32BitConstants = Params->Num32BitConstants.value();
 
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::bReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
     return std::nullopt;
   }
 
@@ -176,11 +175,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (Params->Space.has_value())
     Constants.Space = Params->Space.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Constants;
 }
 
@@ -192,8 +186,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
 
   TokenKind DescriptorKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootDescriptor Descriptor;
@@ -216,15 +209,16 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   }
   Descriptor.setDefaultFlags(Version);
 
-  auto Params = parseRootDescriptorParams(ExpectedReg);
+  auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -240,11 +234,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   if (Params->Flags.has_value())
     Descriptor.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Descriptor;
 }
 
@@ -252,30 +241,27 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
   assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTable Table;
   std::optional<llvm::dxbc::ShaderVisibility> Visibility;
 
-  // Iterate as many Clauses as possible
-  do {
+  // Iterate as many Clauses as possible, until we hit ')'
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+      // DescriptorTableClause - CBV, SRV, UAV, or Sampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Clause = parseDescriptorTableClause();
       if (!Clause.has_value())
         return std::nullopt;
       Elements.push_back(RootSignatureElement(ElementLoc, *Clause));
       Table.NumClauses++;
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // visibility = SHADER_VISIBILITY
       if (Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -285,18 +271,25 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
       Visibility = parseShaderVisibility();
       if (!Visibility.has_value())
         return std::nullopt;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_DescriptorTable;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
 
   // Fill in optional visibility
   if (Visibility.has_value())
     Table.Visibility = Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_DescriptorTable))
-    return std::nullopt;
-
   return Table;
 }
 
@@ -310,8 +303,7 @@ RootSignatureParser::parseDescriptorTableClause() {
 
   TokenKind ParamKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTableClause Clause;
@@ -338,15 +330,16 @@ RootSignatureParser::parseDescriptorTableClause() {
   }
   Clause.setDefaultFlags(Version);
 
-  auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+  auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -365,11 +358,6 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Flags.has_value())
     Clause.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/ParamKind))
-    return std::nullopt;
-
   return Clause;
 }
 
@@ -377,8 +365,7 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   StaticSampler Sampler;
@@ -387,11 +374,12 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::sReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
     return std::nullopt;
   }
 
@@ -434,11 +422,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (Params->Visibility.has_value())
     Sampler.Visibility = Params->Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_StaticSampler))
-    return std::nullopt;
-
   return Sampler;
 }
 
@@ -451,13 +434,11 @@ RootSignatureParser::parseRootConstantParams() {
          "Expects to only be invoked starting at given token");
 
   ParsedConstantParams Params;
-  do {
-    // `num32BitConstants` `=` POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      // `num32BitConstants` `=` POS_INT
       if (Params.Num32BitConstants.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -468,28 +449,20 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Num32BitConstants.has_value())
         return std::nullopt;
       Params.Num32BitConstants = Num32BitConstants;
-    }
-
-    // `b` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+    } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      // `b` POS_INT
       if (Params.Reg.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
       auto Reg = parseRegister();
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -500,14 +473,10 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -518,39 +487,43 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Visibility.has_value())
         return std::nullopt;
       Params.Visibility = Visibility;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootConstants;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
 
 std::optional<RootSignatureParser::ParsedRootDescriptorParams>
-RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
+RootSignatureParser::parseRootDescriptorParams(To...
[truncated]

@inbelic inbelic linked an issue Jul 7, 2025 that may be closed by this pull request
2 tasks
@inbelic inbelic marked this pull request as draft July 7, 2025 17:55
@inbelic
Copy link
Contributor Author

inbelic commented Jul 7, 2025

Contemplating if I should split this into two prs. Will see if there is a nice way to de-couple the improve and fix error portions of this.

@inbelic inbelic changed the base branch from users/inbelic/pr-147115 to main July 8, 2025 17:59
@inbelic inbelic force-pushed the inbelic/rs-improve-diags branch from a561510 to dfde6d4 Compare July 8, 2025 17:59
@inbelic inbelic marked this pull request as ready for review July 8, 2025 18:00
@inbelic
Copy link
Contributor Author

inbelic commented Jul 8, 2025

Updated to rebase onto main so that it will merge before #147115. Removes the 'improve diag' portion. I will create a follow-up issue for that to track the improvement of diagnostic.

[RootSignature(MultiLineRootSignature)]
void bad_root_signature_6() {}

// expected-error@+1 {{expected end of stream to denote end of parameters, or, another valid parameter of RootSignature}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this diagnostic just be expected ','? It seems like all the tests flag cases where a comma is expected but not found.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar parsing error in C++ would result in expected ')':

https://godbolt.org/z/z4Gf1Tar6

I think simplifying to expected ',' and/or expected ')' where appropriate will be more understandable to users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I think we can simplify the diagnostic here quite a bit.

A similar concern was also noted here: #145827 (comment)

I will create a follow-up issue tomorrow to track this work and do so in a follow-up pr, but will leave this pr to just focus on the bug fix as it has a dependency here: #147115 (comment).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the pr and issue to track this.

@inbelic inbelic force-pushed the inbelic/rs-improve-diags branch from 29f7bad to b9cf614 Compare July 9, 2025 15:31
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do when there are commas at the ends of lists of elements?

Interestingly, DXC seems inconsistent on its behaviour for those:

// Unexpected token ')'
[RootSignature("CBV(b0), CBV(b1,)")]
// valid
[RootSignature("CBV(b0), CBV(b1),")]

I don't know that we need to match this exactly - we should probably be consistent about it. In any case, please do add some tests that make sure we do something sensible.

inbelic added 2 commits July 9, 2025 17:23
this worked before because we returned on the first error found
@inbelic
Copy link
Contributor Author

inbelic commented Jul 9, 2025

Added a test to show that it is consistent in allowing a trailing comma after parameter/values

// - a single trailing comma is allowed after any parameter
// - a trailing comma is not required

[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0,),),")]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we reject multiple trailing commas? Something like:

[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0)),,")]

Copy link
Contributor Author

@inbelic inbelic Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intended behaviour is to reject multiple commas. It was just quite hard to decouple from the implementation of a new diag done here. And we will get an error for multiple commas from this change there:

} else {
consumeNextToken(); // let diagnostic be at the start of invalid token
reportDiag(diag::err_hlsl_invalid_token)
<< /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
return true;
}

and will be tested here:

https://github.com/llvm/llvm-project/blob/7ec7e32d2ac4945a489d5463b9fb700b0cceff9d/clang/test/SemaHLSL/RootSignature-err.hlsl#L117

For context, the original draft of this pr, had both an improvement of diagnostic but I wanted to trim it down as much as possible to just be a fix, so that we could get it in before the other pr.

@inbelic inbelic merged commit d60da27 into llvm:main Jul 10, 2025
10 checks passed
inbelic added a commit that referenced this pull request Jul 12, 2025
…on (#147800)

This pr audits the diagnostics produced in `RootSignatureParser`
diagnostics.

First, it has been noted more than once that the current
`diag::err_hlsl_unexpected_end_of_params` is not direct and can be
misleading. For instance,
[here](#147350 (comment))
and
[here](#145827 (comment)).

This pr address this by removing this diagnostic and replacing it with a
more direct `diag::err_expected_either`. However, doing so removes the
nuance between the case where it is a missing comma and when it was an
invalid parameter.

Hence, we introduce the `diag::err_hlsl_invalid_token` for the latter
case, which does so in a direct way. Further, we can apply the same
diagnostic to the parsing of parameter values.

As part of this, we see that there was a test gap in testing the
diagnostics produced from `diag::err_expected_after` and for the parsing
of enum/flag values. As such, these are also addressed here to provide
sufficient unit/sema test coverage.

- Removes all uses of `diag::err_hlsl_unexpected_end_of_params` in
`RootSigantureParser`
- Introduce `diag::err_hlsl_invalid_token` to provide a direct
diagnostic
- In each of these cases, replace the diagnostic with either a
`diag::err_hlsl_invalid_token` or `diag::err_expected_either`
accordingly
- Update `HLSLRootSignatureParserTest` to account for diagnostic changes
- Increase test coverage of `HLSLRootSignatureParserTest` for enum/flag
diagnostics
- Increase test coverage of `RootSignatures-err` for enum/flag
diagnostics
- Small fix-up of the `diag::err_expected_after` and add test to
demonstrate usage

Resolves: #147799
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jul 12, 2025
…ic production (#147800)

This pr audits the diagnostics produced in `RootSignatureParser`
diagnostics.

First, it has been noted more than once that the current
`diag::err_hlsl_unexpected_end_of_params` is not direct and can be
misleading. For instance,
[here](llvm/llvm-project#147350 (comment))
and
[here](llvm/llvm-project#145827 (comment)).

This pr address this by removing this diagnostic and replacing it with a
more direct `diag::err_expected_either`. However, doing so removes the
nuance between the case where it is a missing comma and when it was an
invalid parameter.

Hence, we introduce the `diag::err_hlsl_invalid_token` for the latter
case, which does so in a direct way. Further, we can apply the same
diagnostic to the parsing of parameter values.

As part of this, we see that there was a test gap in testing the
diagnostics produced from `diag::err_expected_after` and for the parsing
of enum/flag values. As such, these are also addressed here to provide
sufficient unit/sema test coverage.

- Removes all uses of `diag::err_hlsl_unexpected_end_of_params` in
`RootSigantureParser`
- Introduce `diag::err_hlsl_invalid_token` to provide a direct
diagnostic
- In each of these cases, replace the diagnostic with either a
`diag::err_hlsl_invalid_token` or `diag::err_expected_either`
accordingly
- Update `HLSLRootSignatureParserTest` to account for diagnostic changes
- Increase test coverage of `HLSLRootSignatureParserTest` for enum/flag
diagnostics
- Increase test coverage of `RootSignatures-err` for enum/flag
diagnostics
- Small fix-up of the `diag::err_expected_after` and add test to
demonstrate usage

Resolves: llvm/llvm-project#147799
@inbelic inbelic deleted the inbelic/rs-improve-diags branch July 14, 2025 17:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL][RootSignature] Incorrectly allows specifying parameters without a comma

4 participants