@@ -117,16 +117,26 @@ bool RootSignatureParser::parseDescriptorTableClause() {
117117 ParamKind))
118118 return true ;
119119
120- llvm::SmallDenseMap<TokenKind, ParamType> Params = {
121- {ExpectedRegister, &Clause.Register },
122- {TokenKind::kw_space, &Clause.Space },
123- };
124- llvm::SmallDenseSet<TokenKind> Mandatory = {
120+ TokenKind Keywords[2 ] = {
125121 ExpectedRegister,
122+ TokenKind::kw_space,
126123 };
124+ ParsedParamState Params (Keywords, ParamKind, CurToken.TokLoc );
125+ if (parseParams (Params))
126+ return true ;
127127
128- if (parseParams (Params, Mandatory))
128+ // Mandatory parameters:
129+ if (!Params.checkAndClearSeen (ExpectedRegister)) {
130+ getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_missing_param)
131+ << ExpectedRegister;
129132 return true ;
133+ }
134+
135+ Clause.Register = Params.Register ;
136+
137+ // Optional parameters:
138+ if (Params.checkAndClearSeen (TokenKind::kw_space))
139+ Clause.Space = Params.Space ;
130140
131141 if (consumeExpectedToken (TokenKind::pu_r_paren,
132142 diag::err_hlsl_unexpected_end_of_params,
@@ -137,68 +147,71 @@ bool RootSignatureParser::parseDescriptorTableClause() {
137147 return false ;
138148}
139149
140- // Helper struct defined to use the overloaded notation of std::visit.
141- template <class ... Ts> struct ParseParamTypeMethods : Ts... {
142- using Ts::operator ()...;
143- };
144- template <class ... Ts>
145- ParseParamTypeMethods (Ts...) -> ParseParamTypeMethods<Ts...>;
146-
147- bool RootSignatureParser::parseParam (ParamType Ref) {
148- return std::visit (
149- ParseParamTypeMethods{
150- [this ](Register *X) -> bool { return parseRegister (X); },
151- [this ](uint32_t *X) -> bool {
152- return consumeExpectedToken (TokenKind::pu_equal,
153- diag::err_expected_after,
154- CurToken.TokKind ) ||
155- parseUIntParam (X);
156- },
157- },
158- Ref);
150+ size_t RootSignatureParser::ParsedParamState::getKeywordIdx (
151+ RootSignatureToken::Kind Keyword) {
152+ ArrayRef KeywordRef = Keywords;
153+ auto It = llvm::find (KeywordRef, Keyword);
154+ assert (It != KeywordRef.end () && " Did not provide a valid param keyword" );
155+ return std::distance (KeywordRef.begin (), It);
159156}
160157
161- bool RootSignatureParser::parseParams (
162- llvm::SmallDenseMap<TokenKind, ParamType> &Params,
163- llvm::SmallDenseSet<TokenKind> &Mandatory) {
158+ bool RootSignatureParser::ParsedParamState::checkAndSetSeen (
159+ RootSignatureToken::Kind Keyword) {
160+ size_t Idx = getKeywordIdx (Keyword);
161+ bool WasSeen = Seen & (1 << Idx);
162+ Seen |= 1u << Idx;
163+ return WasSeen;
164+ }
164165
165- // Initialize a vector of possible keywords
166- SmallVector<TokenKind> Keywords;
167- for (auto Pair : Params)
168- Keywords.push_back (Pair.first );
166+ bool RootSignatureParser::ParsedParamState::checkAndClearSeen (
167+ RootSignatureToken::Kind Keyword) {
168+ size_t Idx = getKeywordIdx (Keyword);
169+ bool WasSeen = Seen & (1 << Idx);
170+ Seen &= ~(1u << Idx);
171+ return WasSeen;
172+ }
169173
170- // Keep track of which keywords have been seen to report duplicates
171- llvm::SmallDenseSet<TokenKind> Seen;
174+ bool RootSignatureParser::parseParam (ParsedParamState &Params) {
175+ TokenKind Keyword = CurToken.TokKind ;
176+ if (Keyword == TokenKind::bReg || Keyword == TokenKind::tReg ||
177+ Keyword == TokenKind::uReg || Keyword == TokenKind::sReg ) {
178+ return parseRegister (Params.Register );
179+ }
180+
181+ if (consumeExpectedToken (TokenKind::pu_equal, diag::err_expected_after,
182+ Keyword))
183+ return true ;
172184
173- while (tryConsumeExpectedToken (Keywords)) {
174- if (Seen.contains (CurToken.TokKind )) {
185+ switch (Keyword) {
186+ case RootSignatureToken::Kind::kw_space:
187+ return parseUIntParam (Params.Space );
188+ default :
189+ llvm_unreachable (" Switch for consumed keyword was not provided" );
190+ }
191+ }
192+
193+ bool RootSignatureParser::parseParams (ParsedParamState &Params) {
194+ assert (CurToken.TokKind == TokenKind::pu_l_paren &&
195+ " Expects to only be invoked starting at given token" );
196+
197+ while (tryConsumeExpectedToken (Params.Keywords )) {
198+ if (Params.checkAndSetSeen (CurToken.TokKind )) {
175199 getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_repeat_param)
176200 << CurToken.TokKind ;
177201 return true ;
178202 }
179- Seen.insert (CurToken.TokKind );
180203
181- if (parseParam (Params[CurToken. TokKind ] ))
204+ if (parseParam (Params))
182205 return true ;
183206
184207 if (!tryConsumeExpectedToken (TokenKind::pu_comma))
185208 break ;
186209 }
187210
188- bool AllMandatoryDefined = true ;
189- for (auto Kind : Mandatory) {
190- bool SeenParam = Seen.contains (Kind);
191- if (!SeenParam) {
192- getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_missing_param)
193- << Kind;
194- }
195- AllMandatoryDefined &= SeenParam;
196- }
197-
198- return !AllMandatoryDefined;
211+ return false ;
199212}
200213
201- bool RootSignatureParser::parseUIntParam (uint32_t * X) {
214+ bool RootSignatureParser::parseUIntParam (uint32_t & X) {
202215 assert (CurToken.TokKind == TokenKind::pu_equal &&
203216 " Expects to only be invoked starting at given keyword" );
204217 tryConsumeExpectedToken (TokenKind::pu_plus);
@@ -207,7 +220,7 @@ bool RootSignatureParser::parseUIntParam(uint32_t *X) {
207220 handleUIntLiteral (X);
208221}
209222
210- bool RootSignatureParser::parseRegister (Register * Register) {
223+ bool RootSignatureParser::parseRegister (Register & Register) {
211224 assert ((CurToken.TokKind == TokenKind::bReg ||
212225 CurToken.TokKind == TokenKind::tReg ||
213226 CurToken.TokKind == TokenKind::uReg ||
@@ -218,26 +231,26 @@ bool RootSignatureParser::parseRegister(Register *Register) {
218231 default :
219232 llvm_unreachable (" Switch for consumed token was not provided" );
220233 case TokenKind::bReg:
221- Register-> ViewType = RegisterType::BReg;
234+ Register. ViewType = RegisterType::BReg;
222235 break ;
223236 case TokenKind::tReg:
224- Register-> ViewType = RegisterType::TReg;
237+ Register. ViewType = RegisterType::TReg;
225238 break ;
226239 case TokenKind::uReg:
227- Register-> ViewType = RegisterType::UReg;
240+ Register. ViewType = RegisterType::UReg;
228241 break ;
229242 case TokenKind::sReg :
230- Register-> ViewType = RegisterType::SReg;
243+ Register. ViewType = RegisterType::SReg;
231244 break ;
232245 }
233246
234- if (handleUIntLiteral (& Register-> Number ))
247+ if (handleUIntLiteral (Register. Number ))
235248 return true ; // propogate NumericLiteralParser error
236249
237250 return false ;
238251}
239252
240- bool RootSignatureParser::handleUIntLiteral (uint32_t * X) {
253+ bool RootSignatureParser::handleUIntLiteral (uint32_t & X) {
241254 // Parse the numeric value and do semantic checks on its specification
242255 clang::NumericLiteralParser Literal (CurToken.NumSpelling , CurToken.TokLoc ,
243256 PP.getSourceManager (), PP.getLangOpts (),
@@ -256,7 +269,7 @@ bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
256269 return true ;
257270 }
258271
259- * X = Val.getExtValue ();
272+ X = Val.getExtValue ();
260273 return false ;
261274}
262275
0 commit comments