Skip to content

[HLSL][RootSignature] Add parsing for RootFlags #138055

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1856,5 +1856,6 @@ def err_hlsl_unexpected_end_of_params
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
def err_hlsl_rootsig_non_zero_flag : Error<"flag value is neither a literal 0 nor a named value">;

} // end of Parser diagnostics
19 changes: 19 additions & 0 deletions clang/include/clang/Lex/HLSLRootSignatureTokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#endif

// Defines the various types of enum
#ifndef ROOT_FLAG_ENUM
#define ROOT_FLAG_ENUM(NAME, LIT) ENUM(NAME, LIT)
#endif
#ifndef UNBOUNDED_ENUM
#define UNBOUNDED_ENUM(NAME, LIT) ENUM(NAME, LIT)
#endif
Expand Down Expand Up @@ -74,6 +77,7 @@ PUNCTUATOR(minus, '-')

// RootElement Keywords:
KEYWORD(RootSignature) // used only for diagnostic messaging
KEYWORD(RootFlags)
KEYWORD(DescriptorTable)
KEYWORD(RootConstants)

Expand Down Expand Up @@ -101,6 +105,20 @@ UNBOUNDED_ENUM(unbounded, "unbounded")
// Descriptor Range Offset Enum:
DESCRIPTOR_RANGE_OFFSET_ENUM(DescriptorRangeOffsetAppend, "DESCRIPTOR_RANGE_OFFSET_APPEND")

// Root Flag Enums:
ROOT_FLAG_ENUM(AllowInputAssemblerInputLayout, "ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT")
ROOT_FLAG_ENUM(DenyVertexShaderRootAccess, "DENY_VERTEX_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(DenyHullShaderRootAccess, "DENY_HULL_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(DenyDomainShaderRootAccess, "DENY_DOMAIN_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(DenyGeometryShaderRootAccess, "DENY_GEOMETRY_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(DenyPixelShaderRootAccess, "DENY_PIXEL_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(DenyAmplificationShaderRootAccess, "DENY_AMPLIFICATION_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(DenyMeshShaderRootAccess, "DENY_MESH_SHADER_ROOT_ACCESS")
ROOT_FLAG_ENUM(AllowStreamOutput, "ALLOW_STREAM_OUTPUT")
ROOT_FLAG_ENUM(LocalRootSignature, "LOCAL_ROOT_SIGNATURE")
ROOT_FLAG_ENUM(CBVSRVUAVHeapDirectlyIndexed, "CBV_SRV_UAV_HEAP_DIRECTLY_INDEXED")
ROOT_FLAG_ENUM(SamplerHeapDirectlyIndexed , "SAMPLER_HEAP_DIRECTLY_INDEXED")

// Root Descriptor Flag Enums:
ROOT_DESCRIPTOR_FLAG_ENUM(DataVolatile, "DATA_VOLATILE")
ROOT_DESCRIPTOR_FLAG_ENUM(DataStaticWhileSetAtExecute, "DATA_STATIC_WHILE_SET_AT_EXECUTE")
Expand Down Expand Up @@ -128,6 +146,7 @@ SHADER_VISIBILITY_ENUM(Mesh, "SHADER_VISIBILITY_MESH")
#undef DESCRIPTOR_RANGE_FLAG_ENUM_OFF
#undef DESCRIPTOR_RANGE_FLAG_ENUM_ON
#undef ROOT_DESCRIPTOR_FLAG_ENUM
#undef ROOT_FLAG_ENUM
#undef DESCRIPTOR_RANGE_OFFSET_ENUM
#undef UNBOUNDED_ENUM
#undef ENUM
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class RootSignatureParser {
// expected, or, there is a lexing error

/// Root Element parse methods:
std::optional<llvm::hlsl::rootsig::RootFlags> parseRootFlags();
std::optional<llvm::hlsl::rootsig::RootConstants> parseRootConstants();
std::optional<llvm::hlsl::rootsig::DescriptorTable> parseDescriptorTable();
std::optional<llvm::hlsl::rootsig::DescriptorTableClause>
Expand Down
73 changes: 63 additions & 10 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
bool RootSignatureParser::parse() {
// Iterate as many RootElements as possible
do {
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
auto Flags = parseRootFlags();
if (!Flags.has_value())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if (!Flags.has_value())
if (!Flags)

return true;
Elements.push_back(*Flags);
}

if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
auto Constants = parseRootConstants();
if (!Constants.has_value())
Expand All @@ -47,6 +54,61 @@ bool RootSignatureParser::parse() {
/*param of=*/TokenKind::kw_RootSignature);
}

template <typename FlagType>
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
if (!Flags.has_value())
return Flag;

return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
llvm::to_underlying(Flag));
}

std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
"Expects to only be invoked starting at given keyword");

if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
CurToken.TokKind))
return std::nullopt;

std::optional<RootFlags> Flags = RootFlags::None;

// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
return std::nullopt;
}
} else {
// Otherwise, parse as many flags as possible
TokenKind Expected[] = {
#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
};

do {
if (tryConsumeExpectedToken(Expected)) {
switch (CurToken.TokKind) {
#define ROOT_FLAG_ENUM(NAME, LIT) \
case TokenKind::en_##NAME: \
Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
break;
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
default:
llvm_unreachable("Switch for consumed enum token was not provided");
}
}
} while (tryConsumeExpectedToken(TokenKind::pu_or));
}

if (consumeExpectedToken(TokenKind::pu_r_paren,
diag::err_hlsl_unexpected_end_of_params,
/*param of=*/TokenKind::kw_RootFlags))
return std::nullopt;

return Flags;
}

std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
"Expects to only be invoked starting at given keyword");
Expand Down Expand Up @@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() {
return std::nullopt;
}

template <typename FlagType>
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
if (!Flags.has_value())
return Flag;

return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
llvm::to_underlying(Flag));
}

std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
RootSignatureParser::parseDescriptorRangeFlags() {
assert(CurToken.TokKind == TokenKind::pu_equal &&
Expand All @@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() {
// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
return std::nullopt;
}
return DescriptorRangeFlags::None;
Expand Down
15 changes: 14 additions & 1 deletion clang/unittests/Lex/LexHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
RootSignature
DescriptorTable RootConstants
RootFlags DescriptorTable RootConstants
num32BitConstants
Expand All @@ -139,6 +139,19 @@ TEST_F(LexHLSLRootSignatureTest, ValidLexAllTokensTest) {
unbounded
DESCRIPTOR_RANGE_OFFSET_APPEND
allow_input_assembler_input_layout
deny_vertex_shader_root_access
deny_hull_shader_root_access
deny_domain_shader_root_access
deny_geometry_shader_root_access
deny_pixel_shader_root_access
deny_amplification_shader_root_access
deny_mesh_shader_root_access
allow_stream_output
local_root_signature
cbv_srv_uav_heap_directly_indexed
sampler_heap_directly_indexed
DATA_VOLATILE
DATA_STATIC_WHILE_SET_AT_EXECUTE
DATA_STATIC
Expand Down
52 changes: 51 additions & 1 deletion clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,56 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
const llvm::StringLiteral Source = R"cc(
RootFlags(),
RootFlags(0),
RootFlags(
deny_domain_shader_root_access |
deny_pixel_shader_root_access |
local_root_signature |
cbv_srv_uav_heap_directly_indexed |
deny_amplification_shader_root_access |
deny_geometry_shader_root_access |
deny_hull_shader_root_access |
deny_mesh_shader_root_access |
allow_stream_output |
sampler_heap_directly_indexed |
allow_input_assembler_input_layout |
deny_vertex_shader_root_access
)
)cc";

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);
auto TokLoc = SourceLocation();

hlsl::RootSignatureLexer Lexer(Source, TokLoc);
SmallVector<RootElement> Elements;
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);

// Test no diagnostics produced
Consumer->setNoDiag();

ASSERT_FALSE(Parser.parse());

ASSERT_EQ(Elements.size(), 3u);

RootElement Elem = Elements[0];
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);

Elem = Elements[1];
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);

Elem = Elements[2];
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags);

ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
// This test will checks we can handling trailing commas ','
const llvm::StringLiteral Source = R"cc(
Expand Down Expand Up @@ -566,7 +616,7 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidNonZeroFlagsTest) {
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);

// Test correct diagnostic produced
Consumer->setExpected(diag::err_expected);
Consumer->setExpected(diag::err_hlsl_rootsig_non_zero_flag);
ASSERT_TRUE(Parser.parse());

ASSERT_TRUE(Consumer->isSatisfied());
Expand Down
21 changes: 19 additions & 2 deletions llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ namespace rootsig {

// Definition of the various enumerations and flags

enum class RootFlags : uint32_t {
None = 0,
AllowInputAssemblerInputLayout = 0x1,
DenyVertexShaderRootAccess = 0x2,
DenyHullShaderRootAccess = 0x4,
DenyDomainShaderRootAccess = 0x8,
DenyGeometryShaderRootAccess = 0x10,
DenyPixelShaderRootAccess = 0x20,
AllowStreamOutput = 0x40,
LocalRootSignature = 0x80,
DenyAmplificationShaderRootAccess = 0x100,
DenyMeshShaderRootAccess = 0x200,
CBVSRVUAVHeapDirectlyIndexed = 0x400,
SamplerHeapDirectlyIndexed = 0x800,
ValidFlags = 0x00000fff
};

enum class DescriptorRangeFlags : unsigned {
None = 0,
DescriptorsVolatile = 0x1,
Expand Down Expand Up @@ -97,8 +114,8 @@ struct DescriptorTableClause {
};

// Models RootElement : RootConstants | DescriptorTable | DescriptorTableClause
using RootElement =
std::variant<RootConstants, DescriptorTable, DescriptorTableClause>;
using RootElement = std::variant<RootFlags, RootConstants, DescriptorTable,
DescriptorTableClause>;

} // namespace rootsig
} // namespace hlsl
Expand Down
Loading