@@ -27,6 +27,13 @@ RootSignatureParser::RootSignatureParser(SmallVector<RootElement> &Elements,
2727bool RootSignatureParser::parse () {
2828 // Iterate as many RootElements as possible
2929 do {
30+ if (tryConsumeExpectedToken (TokenKind::kw_RootFlags)) {
31+ auto Flags = parseRootFlags ();
32+ if (!Flags.has_value ())
33+ return true ;
34+ Elements.push_back (*Flags);
35+ }
36+
3037 if (tryConsumeExpectedToken (TokenKind::kw_RootConstants)) {
3138 auto Constants = parseRootConstants ();
3239 if (!Constants.has_value ())
@@ -47,6 +54,61 @@ bool RootSignatureParser::parse() {
4754 /* param of=*/ TokenKind::kw_RootSignature);
4855}
4956
57+ template <typename FlagType>
58+ static FlagType maybeOrFlag (std::optional<FlagType> Flags, FlagType Flag) {
59+ if (!Flags.has_value ())
60+ return Flag;
61+
62+ return static_cast <FlagType>(llvm::to_underlying (Flags.value ()) |
63+ llvm::to_underlying (Flag));
64+ }
65+
66+ std::optional<RootFlags> RootSignatureParser::parseRootFlags () {
67+ assert (CurToken.TokKind == TokenKind::kw_RootFlags &&
68+ " Expects to only be invoked starting at given keyword" );
69+
70+ if (consumeExpectedToken (TokenKind::pu_l_paren, diag::err_expected_after,
71+ CurToken.TokKind ))
72+ return std::nullopt ;
73+
74+ std::optional<RootFlags> Flags = RootFlags::None;
75+
76+ // Handle the edge-case of '0' to specify no flags set
77+ if (tryConsumeExpectedToken (TokenKind::int_literal)) {
78+ if (!verifyZeroFlag ()) {
79+ getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_non_zero_flag);
80+ return std::nullopt ;
81+ }
82+ } else {
83+ // Otherwise, parse as many flags as possible
84+ TokenKind Expected[] = {
85+ #define ROOT_FLAG_ENUM (NAME, LIT ) TokenKind::en_##NAME,
86+ #include " clang/Lex/HLSLRootSignatureTokenKinds.def"
87+ };
88+
89+ do {
90+ if (tryConsumeExpectedToken (Expected)) {
91+ switch (CurToken.TokKind ) {
92+ #define ROOT_FLAG_ENUM (NAME, LIT ) \
93+ case TokenKind::en_##NAME: \
94+ Flags = maybeOrFlag<RootFlags>(Flags, RootFlags::NAME); \
95+ break ;
96+ #include " clang/Lex/HLSLRootSignatureTokenKinds.def"
97+ default :
98+ llvm_unreachable (" Switch for consumed enum token was not provided" );
99+ }
100+ }
101+ } while (tryConsumeExpectedToken (TokenKind::pu_or));
102+ }
103+
104+ if (consumeExpectedToken (TokenKind::pu_r_paren,
105+ diag::err_hlsl_unexpected_end_of_params,
106+ /* param of=*/ TokenKind::kw_RootFlags))
107+ return std::nullopt ;
108+
109+ return Flags;
110+ }
111+
50112std::optional<RootConstants> RootSignatureParser::parseRootConstants () {
51113 assert (CurToken.TokKind == TokenKind::kw_RootConstants &&
52114 " Expects to only be invoked starting at given keyword" );
@@ -467,15 +529,6 @@ RootSignatureParser::parseShaderVisibility() {
467529 return std::nullopt ;
468530}
469531
470- template <typename FlagType>
471- static FlagType maybeOrFlag (std::optional<FlagType> Flags, FlagType Flag) {
472- if (!Flags.has_value ())
473- return Flag;
474-
475- return static_cast <FlagType>(llvm::to_underlying (Flags.value ()) |
476- llvm::to_underlying (Flag));
477- }
478-
479532std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
480533RootSignatureParser::parseDescriptorRangeFlags () {
481534 assert (CurToken.TokKind == TokenKind::pu_equal &&
@@ -484,7 +537,7 @@ RootSignatureParser::parseDescriptorRangeFlags() {
484537 // Handle the edge-case of '0' to specify no flags set
485538 if (tryConsumeExpectedToken (TokenKind::int_literal)) {
486539 if (!verifyZeroFlag ()) {
487- getDiags ().Report (CurToken.TokLoc , diag::err_expected) << " '0' " ;
540+ getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_non_zero_flag) ;
488541 return std::nullopt ;
489542 }
490543 return DescriptorRangeFlags::None;
0 commit comments