@@ -54,6 +54,15 @@ bool RootSignatureParser::parse() {
54
54
/* param of=*/ TokenKind::kw_RootSignature);
55
55
}
56
56
57
+ template <typename FlagType>
58
+ static FlagType maybeOrFlag (std::optional<FlagType> Flags, FlagType Flag) {
59
+ if (!Flags.has_value ())
60
+ return Flag;
61
+
62
+ return static_cast <FlagType>(llvm::to_underlying (Flags.value ()) |
63
+ llvm::to_underlying (Flag));
64
+ }
65
+
57
66
std::optional<RootFlags> RootSignatureParser::parseRootFlags () {
58
67
assert (CurToken.TokKind == TokenKind::kw_RootFlags &&
59
68
" Expects to only be invoked starting at given keyword" );
@@ -62,7 +71,36 @@ std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
62
71
CurToken.TokKind ))
63
72
return std::nullopt;
64
73
65
- RootFlags Flags = RootFlags::None;
74
+ std::optional<RootFlags> Flags = RootFlags::None;
75
+
76
+ // Handle the edge-case of '0' to specify no flags set
77
+ if (tryConsumeExpectedToken (TokenKind::int_literal)) {
78
+ if (!verifyZeroFlag ()) {
79
+ getDiags ().Report (CurToken.TokLoc , diag::err_expected) << " '0'" ;
80
+ return std::nullopt;
81
+ }
82
+ } else {
83
+ // Otherwise, parse as many flags as possible
84
+ TokenKind Expected[] = {
85
+ #define ROOT_FLAG_ENUM (NAME, LIT ) TokenKind::en_##NAME,
86
+ #include " clang/Lex/HLSLRootSignatureTokenKinds.def"
87
+ };
88
+
89
+ do {
90
+ if (tryConsumeExpectedToken (Expected)) {
91
+ switch (CurToken.TokKind ) {
92
+ #define ROOT_FLAG_ENUM (NAME, LIT ) \
93
+ case TokenKind::en_##NAME: \
94
+ Flags = \
95
+ maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
96
+ break ;
97
+ #include " clang/Lex/HLSLRootSignatureTokenKinds.def"
98
+ default :
99
+ llvm_unreachable (" Switch for consumed enum token was not provided" );
100
+ }
101
+ }
102
+ } while (tryConsumeExpectedToken (TokenKind::pu_or));
103
+ }
66
104
67
105
if (consumeExpectedToken (TokenKind::pu_r_paren,
68
106
diag::err_hlsl_unexpected_end_of_params,
@@ -492,15 +530,6 @@ RootSignatureParser::parseShaderVisibility() {
492
530
return std::nullopt;
493
531
}
494
532
495
- template <typename FlagType>
496
- static FlagType maybeOrFlag (std::optional<FlagType> Flags, FlagType Flag) {
497
- if (!Flags.has_value ())
498
- return Flag;
499
-
500
- return static_cast <FlagType>(llvm::to_underlying (Flags.value ()) |
501
- llvm::to_underlying (Flag));
502
- }
503
-
504
533
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
505
534
RootSignatureParser::parseDescriptorRangeFlags () {
506
535
assert (CurToken.TokKind == TokenKind::pu_equal &&
0 commit comments