@@ -54,6 +54,15 @@ bool RootSignatureParser::parse() {
5454 /* param of=*/ TokenKind::kw_RootSignature);
5555}
5656
57+ template <typename FlagType>
58+ static FlagType maybeOrFlag (std::optional<FlagType> Flags, FlagType Flag) {
59+ if (!Flags.has_value ())
60+ return Flag;
61+
62+ return static_cast <FlagType>(llvm::to_underlying (Flags.value ()) |
63+ llvm::to_underlying (Flag));
64+ }
65+
5766std::optional<RootFlags> RootSignatureParser::parseRootFlags () {
5867 assert (CurToken.TokKind == TokenKind::kw_RootFlags &&
5968 " Expects to only be invoked starting at given keyword" );
@@ -62,7 +71,36 @@ std::optional<RootFlags> RootSignatureParser::parseRootFlags() {
6271 CurToken.TokKind ))
6372 return std::nullopt ;
6473
65- RootFlags Flags = RootFlags::None;
74+ std::optional<RootFlags> Flags = RootFlags::None;
75+
76+ // Handle the edge-case of '0' to specify no flags set
77+ if (tryConsumeExpectedToken (TokenKind::int_literal)) {
78+ if (!verifyZeroFlag ()) {
79+ getDiags ().Report (CurToken.TokLoc , diag::err_expected) << " '0'" ;
80+ return std::nullopt ;
81+ }
82+ } else {
83+ // Otherwise, parse as many flags as possible
84+ TokenKind Expected[] = {
85+ #define ROOT_FLAG_ENUM (NAME, LIT ) TokenKind::en_##NAME,
86+ #include " clang/Lex/HLSLRootSignatureTokenKinds.def"
87+ };
88+
89+ do {
90+ if (tryConsumeExpectedToken (Expected)) {
91+ switch (CurToken.TokKind ) {
92+ #define ROOT_FLAG_ENUM (NAME, LIT ) \
93+ case TokenKind::en_##NAME: \
94+ Flags = \
95+ maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
96+ break ;
97+ #include " clang/Lex/HLSLRootSignatureTokenKinds.def"
98+ default :
99+ llvm_unreachable (" Switch for consumed enum token was not provided" );
100+ }
101+ }
102+ } while (tryConsumeExpectedToken (TokenKind::pu_or));
103+ }
66104
67105 if (consumeExpectedToken (TokenKind::pu_r_paren,
68106 diag::err_hlsl_unexpected_end_of_params,
@@ -492,15 +530,6 @@ RootSignatureParser::parseShaderVisibility() {
492530 return std::nullopt ;
493531}
494532
495- template <typename FlagType>
496- static FlagType maybeOrFlag (std::optional<FlagType> Flags, FlagType Flag) {
497- if (!Flags.has_value ())
498- return Flag;
499-
500- return static_cast <FlagType>(llvm::to_underlying (Flags.value ()) |
501- llvm::to_underlying (Flag));
502- }
503-
504533std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
505534RootSignatureParser::parseDescriptorRangeFlags () {
506535 assert (CurToken.TokKind == TokenKind::pu_equal &&
0 commit comments