Skip to content

Commit b471274

Browse files
committed
implement parsing of the flags
1 parent 2f8222e commit b471274

File tree

2 files changed

+64
-12
lines changed

2 files changed

+64
-12
lines changed

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ bool RootSignatureParser::parse() {
5454
/*param of=*/TokenKind::kw_RootSignature);
5555
}
5656

57+
template <typename FlagType>
58+
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
59+
if (!Flags.has_value())
60+
return Flag;
61+
62+
return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
63+
llvm::to_underlying(Flag));
64+
}
65+
5766
std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
5867
assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
5968
"Expects to only be invoked starting at given keyword");
@@ -62,7 +71,36 @@ std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
6271
CurToken.TokKind))
6372
return std::nullopt;
6473

65-
RootFlags Flags = RootFlags::None;
74+
std::optional<RootFlags> Flags = RootFlags::None;
75+
76+
// Handle the edge-case of '0' to specify no flags set
77+
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
78+
if (!verifyZeroFlag()) {
79+
getDiags().Report(CurToken.TokLoc, diag::err_expected) << "'0'";
80+
return std::nullopt;
81+
}
82+
} else {
83+
// Otherwise, parse as many flags as possible
84+
TokenKind Expected[] = {
85+
#define ROOT_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
86+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
87+
};
88+
89+
do {
90+
if (tryConsumeExpectedToken(Expected)) {
91+
switch (CurToken.TokKind) {
92+
#define ROOT_FLAG_ENUM(NAME, LIT) \
93+
case TokenKind::en_##NAME: \
94+
Flags = \
95+
maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
96+
break;
97+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
98+
default:
99+
llvm_unreachable("Switch for consumed enum token was not provided");
100+
}
101+
}
102+
} while (tryConsumeExpectedToken(TokenKind::pu_or));
103+
}
66104

67105
if (consumeExpectedToken(TokenKind::pu_r_paren,
68106
diag::err_hlsl_unexpected_end_of_params,
@@ -492,15 +530,6 @@ RootSignatureParser::parseShaderVisibility() {
492530
return std::nullopt;
493531
}
494532

495-
template <typename FlagType>
496-
static FlagType maybeOrFlag(std::optional<FlagType> Flags, FlagType Flag) {
497-
if (!Flags.has_value())
498-
return Flag;
499-
500-
return static_cast<FlagType>(llvm::to_underlying(Flags.value()) |
501-
llvm::to_underlying(Flag));
502-
}
503-
504533
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
505534
RootSignatureParser::parseDescriptorRangeFlags() {
506535
assert(CurToken.TokKind == TokenKind::pu_equal &&

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,22 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
296296

297297
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
298298
const llvm::StringLiteral Source = R"cc(
299-
RootFlags()
299+
RootFlags(),
300+
RootFlags(0),
301+
RootFlags(
302+
deny_domain_shader_root_access |
303+
deny_pixel_shader_root_access |
304+
local_root_signature |
305+
cbv_srv_uav_heap_directly_indexed |
306+
deny_amplification_shader_root_access |
307+
deny_geometry_shader_root_access |
308+
deny_hull_shader_root_access |
309+
deny_mesh_shader_root_access |
310+
allow_stream_output |
311+
sampler_heap_directly_indexed |
312+
allow_input_assembler_input_layout |
313+
deny_vertex_shader_root_access
314+
)
300315
)cc";
301316

302317
TrivialModuleLoader ModLoader;
@@ -312,12 +327,20 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
312327

313328
ASSERT_FALSE(Parser.parse());
314329

315-
ASSERT_EQ(Elements.size(), 1u);
330+
ASSERT_EQ(Elements.size(), 3u);
316331

317332
RootElement Elem = Elements[0];
318333
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
319334
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
320335

336+
Elem = Elements[1];
337+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
338+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
339+
340+
Elem = Elements[2];
341+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
342+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::ValidFlags);
343+
321344
ASSERT_TRUE(Consumer->isSatisfied());
322345
}
323346

0 commit comments

Comments
 (0)