diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index d9f121030c1fc..d639ca91c002f 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -85,9 +85,13 @@ class RootSignatureParser { std::optional parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType); + // Common parsing methods std::optional parseUIntParam(); std::optional parseRegister(); + /// Parsing methods of various enums + std::optional parseShaderVisibility(); + /// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned /// 32-bit integer std::optional handleUIntLiteral(); diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index 1bf33b8e8329c..8244e91c8f89a 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -52,6 +52,7 @@ std::optional RootSignatureParser::parseDescriptorTable() { return std::nullopt; DescriptorTable Table; + std::optional Visibility; // Iterate as many Clauses as possible do { @@ -63,8 +64,27 @@ std::optional RootSignatureParser::parseDescriptorTable() { Elements.push_back(*Clause); Table.NumClauses++; } + + if (tryConsumeExpectedToken(TokenKind::kw_visibility)) { + if (Visibility.has_value()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param) + << CurToken.TokKind; + return std::nullopt; + } + + if (consumeExpectedToken(TokenKind::pu_equal)) + return std::nullopt; + + Visibility = parseShaderVisibility(); + if (!Visibility.has_value()) + return std::nullopt; + } } while (tryConsumeExpectedToken(TokenKind::pu_comma)); + // Fill in optional visibility + if (Visibility.has_value()) + Table.Visibility = Visibility.value(); + if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params, /*param of=*/TokenKind::kw_DescriptorTable)) @@ -222,6 +242,32 @@ std::optional RootSignatureParser::parseRegister() { return Reg; } +std::optional +RootSignatureParser::parseShaderVisibility() { + assert(CurToken.TokKind == TokenKind::pu_equal && + "Expects to only be invoked starting at given keyword"); + + TokenKind Expected[] = { +#define SHADER_VISIBILITY_ENUM(NAME, LIT) TokenKind::en_##NAME, +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + }; + + if (!tryConsumeExpectedToken(Expected)) + return std::nullopt; + + switch (CurToken.TokKind) { +#define SHADER_VISIBILITY_ENUM(NAME, LIT) \ + case TokenKind::en_##NAME: \ + return ShaderVisibility::NAME; \ + break; +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + default: + llvm_unreachable("Switch for consumed enum token was not provided"); + } + + return std::nullopt; +} + std::optional RootSignatureParser::handleUIntLiteral() { // Parse the numeric value and do semantic checks on its specification clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc, diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index e382a1b26d366..1d89567509e72 100644 --- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -131,6 +131,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { DescriptorTable( CBV(b0), SRV(space = 3, t42), + visibility = SHADER_VISIBILITY_PIXEL, Sampler(s987, space = +2), UAV(u4294967294) ), @@ -186,11 +187,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) { Elem = Elements[4]; ASSERT_TRUE(std::holds_alternative(Elem)); ASSERT_EQ(std::get(Elem).NumClauses, (uint32_t)4); + ASSERT_EQ(std::get(Elem).Visibility, + ShaderVisibility::Pixel); // Empty Descriptor Table Elem = Elements[5]; ASSERT_TRUE(std::holds_alternative(Elem)); ASSERT_EQ(std::get(Elem).NumClauses, 0u); + ASSERT_EQ(std::get(Elem).Visibility, ShaderVisibility::All); ASSERT_TRUE(Consumer->isSatisfied()); } diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 778b0c397f9cf..d51b853942dd3 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -21,6 +21,19 @@ namespace llvm { namespace hlsl { namespace rootsig { +// Definition of the various enumerations and flags + +enum class ShaderVisibility { + All = 0, + Vertex = 1, + Hull = 2, + Domain = 3, + Geometry = 4, + Pixel = 5, + Amplification = 6, + Mesh = 7, +}; + // Definitions of the in-memory data layout structures // Models the different registers: bReg | tReg | uReg | sReg @@ -32,6 +45,7 @@ struct Register { // Models the end of a descriptor table and stores its visibility struct DescriptorTable { + ShaderVisibility Visibility = ShaderVisibility::All; uint32_t NumClauses = 0; // The number of clauses in the table };