diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td index 33e9296f39eeb..3bbdc49946dac 100644 --- a/clang/include/clang/Basic/DiagnosticParseKinds.td +++ b/clang/include/clang/Basic/DiagnosticParseKinds.td @@ -1856,5 +1856,6 @@ def err_hlsl_unexpected_end_of_params def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">; def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">; def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">; +def err_hlsl_rootsig_non_zero_flag : Error<"flag value is neither a literal 0 nor a named value">; } // end of Parser diagnostics diff --git a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def index d9edeff7ac567..c6f7f8928bc91 100644 --- a/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def +++ b/clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def @@ -27,6 +27,9 @@ #endif // Defines the various types of enum +#ifndef ROOT_FLAG_ENUM +#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT) +#endif #ifndef UNBOUNDED_ENUM #define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT) #endif @@ -74,6 +77,7 @@ PUNCTUATOR(minus, '-') // RootElement Keywords: KEYWORD(RootSignature) // used only for diagnostic messaging +KEYWORD(RootFlags) KEYWORD(DescriptorTable) KEYWORD(RootConstants) @@ -101,6 +105,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded") // Descriptor Range Offset Enum: DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND") +// Root Flag Enums: +ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT") +ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS") +ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT") +ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE") +ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED") +ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED") + // Root Descriptor Flag Enums: ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE") ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE") @@ -128,6 +146,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH") #undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF #undef DESCRIPTOR_RANGE_FLAG_ENUM_ON #undef ROOT_DESCRIPTOR_FLAG_ENUM +#undef ROOT_FLAG_ENUM #undef DESCRIPTOR_RANGE_OFFSET_ENUM #undef UNBOUNDED_ENUM #undef ENUM diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h index 2ac2083983741..915266f8a36ae 100644 --- a/clang/include/clang/Parse/ParseHLSLRootSignature.h +++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h @@ -71,6 +71,7 @@ class RootSignatureParser { // expected, or, there is a lexing error /// Root Element parse methods: + std::optional parseRootFlags(); std::optional parseRootConstants(); std::optional parseDescriptorTable(); std::optional diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp index 86bf30668db46..5603900429844 100644 --- a/clang/lib/Parse/ParseHLSLRootSignature.cpp +++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp @@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector &Elements, bool RootSignatureParser::parse() { // Iterate as many RootElements as possible do { + if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) { + auto Flags = parseRootFlags(); + if (!Flags.has_value()) + return true; + Elements.push_back(*Flags); + } + if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) { auto Constants = parseRootConstants(); if (!Constants.has_value()) @@ -47,6 +54,61 @@ bool RootSignatureParser::parse() { /*param of=*/TokenKind::kw_RootSignature); } +template +static FlagType maybeOrFlag(std::optional Flags, FlagType Flag) { + if (!Flags.has_value()) + return Flag; + + return static_cast(llvm::to_underlying(Flags.value()) | + llvm::to_underlying(Flag)); +} + +std::optional RootSignatureParser::parseRootFlags() { + assert(CurToken.TokKind == TokenKind::kw_RootFlags && + "Expects to only be invoked starting at given keyword"); + + if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after, + CurToken.TokKind)) + return std::nullopt; + + std::optional Flags = RootFlags::None; + + // Handle the edge-case of '0' to specify no flags set + if (tryConsumeExpectedToken(TokenKind::int_literal)) { + if (!verifyZeroFlag()) { + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag); + return std::nullopt; + } + } else { + // Otherwise, parse as many flags as possible + TokenKind Expected[] = { +#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME, +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + }; + + do { + if (tryConsumeExpectedToken(Expected)) { + switch (CurToken.TokKind) { +#define ROOT_FLAG_ENUM(NAME, LIT) \ + case TokenKind::en_##NAME: \ + Flags = maybeOrFlag(Flags, RootFlags::NAME); \ + break; +#include "clang/Lex/HLSLRootSignatureTokenKinds.def" + default: + llvm_unreachable("Switch for consumed enum token was not provided"); + } + } + } while (tryConsumeExpectedToken(TokenKind::pu_or)); + } + + if (consumeExpectedToken(TokenKind::pu_r_paren, + diag::err_hlsl_unexpected_end_of_params, + /*param of=*/TokenKind::kw_RootFlags)) + return std::nullopt; + + return Flags; +} + std::optional RootSignatureParser::parseRootConstants() { assert(CurToken.TokKind == TokenKind::kw_RootConstants && "Expects to only be invoked starting at given keyword"); @@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() { return std::nullopt; } -template -static FlagType maybeOrFlag(std::optional Flags, FlagType Flag) { - if (!Flags.has_value()) - return Flag; - - return static_cast(llvm::to_underlying(Flags.value()) | - llvm::to_underlying(Flag)); -} - std::optional RootSignatureParser::parseDescriptorRangeFlags() { assert(CurToken.TokKind == TokenKind::pu_equal && @@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() { // Handle the edge-case of '0' to specify no flags set if (tryConsumeExpectedToken(TokenKind::int_literal)) { if (!verifyZeroFlag()) { - getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'"; + getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag); return std::nullopt; } return DescriptorRangeFlags::None; diff --git a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp index a761257149c11..1f8d8be64e323 100644 --- a/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp +++ b/clang/unittests/Lex/LexHLSLRootSignatureTest.cpp @@ -128,7 +128,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) { RootSignature - DescriptorTable RootConstants + RootFlags DescriptorTable RootConstants num32BitConstants @@ -139,6 +139,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) { unbounded DESCRIPTOR_RANGE_OFFSET_APPEND + allow_input_assembler_input_layout + deny_vertex_shader_root_access + deny_hull_shader_root_access + deny_domain_shader_root_access + deny_geometry_shader_root_access + deny_pixel_shader_root_access + deny_amplification_shader_root_access + deny_mesh_shader_root_access + allow_stream_output + local_root_signature + cbv_srv_uav_heap_directly_indexed + sampler_heap_directly_indexed + DATA_VOLATILE DATA_STATIC_WHILE_SET_AT_EXECUTE DATA_STATIC diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp index eda68531de34f..c97f8d0b392d1 100644 --- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp +++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp @@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) { ASSERT_TRUE(Consumer->isSatisfied()); } +TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) { + const llvm::StringLiteral Source = R"cc( + RootFlags(), + RootFlags(0), + RootFlags( + deny_domain_shader_root_access | + deny_pixel_shader_root_access | + local_root_signature | + cbv_srv_uav_heap_directly_indexed | + deny_amplification_shader_root_access | + deny_geometry_shader_root_access | + deny_hull_shader_root_access | + deny_mesh_shader_root_access | + allow_stream_output | + sampler_heap_directly_indexed | + allow_input_assembler_input_layout | + deny_vertex_shader_root_access + ) + )cc"; + + TrivialModuleLoader ModLoader; + auto PP = createPP(Source, ModLoader); + auto TokLoc = SourceLocation(); + + hlsl::RootSignatureLexer Lexer(Source, TokLoc); + SmallVector Elements; + hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); + + // Test no diagnostics produced + Consumer->setNoDiag(); + + ASSERT_FALSE(Parser.parse()); + + ASSERT_EQ(Elements.size(), 3u); + + RootElement Elem = Elements[0]; + ASSERT_TRUE(std::holds_alternative(Elem)); + ASSERT_EQ(std::get(Elem), RootFlags::None); + + Elem = Elements[1]; + ASSERT_TRUE(std::holds_alternative(Elem)); + ASSERT_EQ(std::get(Elem), RootFlags::None); + + Elem = Elements[2]; + ASSERT_TRUE(std::holds_alternative(Elem)); + ASSERT_EQ(std::get(Elem), RootFlags::ValidFlags); + + ASSERT_TRUE(Consumer->isSatisfied()); +} + TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) { // This test will checks we can handling trailing commas ',' const llvm::StringLiteral Source = R"cc( @@ -566,7 +616,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) { hlsl::RootSignatureParser Parser(Elements, Lexer, *PP); // Test correct diagnostic produced - Consumer->setExpected(diag::err_expected); + Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag); ASSERT_TRUE(Parser.parse()); ASSERT_TRUE(Consumer->isSatisfied()); diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h index 8b8324df18bb3..2ecaf69fc2f9c 100644 --- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h +++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h @@ -23,6 +23,23 @@ namespace rootsig { // Definition of the various enumerations and flags +enum class RootFlags : uint32_t { + None = 0, + AllowInputAssemblerInputLayout = 0x1, + DenyVertexShaderRootAccess = 0x2, + DenyHullShaderRootAccess = 0x4, + DenyDomainShaderRootAccess = 0x8, + DenyGeometryShaderRootAccess = 0x10, + DenyPixelShaderRootAccess = 0x20, + AllowStreamOutput = 0x40, + LocalRootSignature = 0x80, + DenyAmplificationShaderRootAccess = 0x100, + DenyMeshShaderRootAccess = 0x200, + CBVSRVUAVHeapDirectlyIndexed = 0x400, + SamplerHeapDirectlyIndexed = 0x800, + ValidFlags = 0x00000fff +}; + enum class DescriptorRangeFlags : unsigned { None = 0, DescriptorsVolatile = 0x1, @@ -97,8 +114,8 @@ struct DescriptorTableClause { }; // Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause -using RootElement = - std::variant; +using RootElement = std::variant; } // namespace rootsig } // namespace hlsl