@@ -117,16 +117,26 @@ bool RootSignatureParser::parseDescriptorTableClause() {
117
117
ParamKind))
118
118
return true ;
119
119
120
- llvm::SmallDenseMap<TokenKind, ParamType> Params = {
121
- {ExpectedRegister, &Clause.Register },
122
- {TokenKind::kw_space, &Clause.Space },
123
- };
124
- llvm::SmallDenseSet<TokenKind> Mandatory = {
120
+ TokenKind Keywords[2 ] = {
125
121
ExpectedRegister,
122
+ TokenKind::kw_space,
126
123
};
124
+ ParsedParamState Params (Keywords, ParamKind, CurToken.TokLoc );
125
+ if (parseParams (Params))
126
+ return true ;
127
127
128
- if (parseParams (Params, Mandatory))
128
+ // Mandatory parameters:
129
+ if (!Params.checkAndClearSeen (ExpectedRegister)) {
130
+ getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_missing_param)
131
+ << ExpectedRegister;
129
132
return true ;
133
+ }
134
+
135
+ Clause.Register = Params.Register ;
136
+
137
+ // Optional parameters:
138
+ if (Params.checkAndClearSeen (TokenKind::kw_space))
139
+ Clause.Space = Params.Space ;
130
140
131
141
if (consumeExpectedToken (TokenKind::pu_r_paren,
132
142
diag::err_hlsl_unexpected_end_of_params,
@@ -137,68 +147,71 @@ bool RootSignatureParser::parseDescriptorTableClause() {
137
147
return false ;
138
148
}
139
149
140
- // Helper struct defined to use the overloaded notation of std::visit.
141
- template <class ... Ts> struct ParseParamTypeMethods : Ts... {
142
- using Ts::operator ()...;
143
- };
144
- template <class ... Ts>
145
- ParseParamTypeMethods (Ts...) -> ParseParamTypeMethods<Ts...>;
146
-
147
- bool RootSignatureParser::parseParam (ParamType Ref) {
148
- return std::visit (
149
- ParseParamTypeMethods{
150
- [this ](Register *X) -> bool { return parseRegister (X); },
151
- [this ](uint32_t *X) -> bool {
152
- return consumeExpectedToken (TokenKind::pu_equal,
153
- diag::err_expected_after,
154
- CurToken.TokKind ) ||
155
- parseUIntParam (X);
156
- },
157
- },
158
- Ref);
150
+ size_t RootSignatureParser::ParsedParamState::getKeywordIdx (
151
+ RootSignatureToken::Kind Keyword) {
152
+ ArrayRef KeywordRef = Keywords;
153
+ auto It = llvm::find (KeywordRef, Keyword);
154
+ assert (It != KeywordRef.end () && " Did not provide a valid param keyword" );
155
+ return std::distance (KeywordRef.begin (), It);
159
156
}
160
157
161
- bool RootSignatureParser::parseParams (
162
- llvm::SmallDenseMap<TokenKind, ParamType> &Params,
163
- llvm::SmallDenseSet<TokenKind> &Mandatory) {
158
+ bool RootSignatureParser::ParsedParamState::checkAndSetSeen (
159
+ RootSignatureToken::Kind Keyword) {
160
+ size_t Idx = getKeywordIdx (Keyword);
161
+ bool WasSeen = Seen & (1 << Idx);
162
+ Seen |= 1u << Idx;
163
+ return WasSeen;
164
+ }
164
165
165
- // Initialize a vector of possible keywords
166
- SmallVector<TokenKind> Keywords;
167
- for (auto Pair : Params)
168
- Keywords.push_back (Pair.first );
166
+ bool RootSignatureParser::ParsedParamState::checkAndClearSeen (
167
+ RootSignatureToken::Kind Keyword) {
168
+ size_t Idx = getKeywordIdx (Keyword);
169
+ bool WasSeen = Seen & (1 << Idx);
170
+ Seen &= ~(1u << Idx);
171
+ return WasSeen;
172
+ }
169
173
170
- // Keep track of which keywords have been seen to report duplicates
171
- llvm::SmallDenseSet<TokenKind> Seen;
174
+ bool RootSignatureParser::parseParam (ParsedParamState &Params) {
175
+ TokenKind Keyword = CurToken.TokKind ;
176
+ if (Keyword == TokenKind::bReg || Keyword == TokenKind::tReg ||
177
+ Keyword == TokenKind::uReg || Keyword == TokenKind::sReg ) {
178
+ return parseRegister (Params.Register );
179
+ }
180
+
181
+ if (consumeExpectedToken (TokenKind::pu_equal, diag::err_expected_after,
182
+ Keyword))
183
+ return true ;
172
184
173
- while (tryConsumeExpectedToken (Keywords)) {
174
- if (Seen.contains (CurToken.TokKind )) {
185
+ switch (Keyword) {
186
+ case RootSignatureToken::Kind::kw_space:
187
+ return parseUIntParam (Params.Space );
188
+ default :
189
+ llvm_unreachable (" Switch for consumed keyword was not provided" );
190
+ }
191
+ }
192
+
193
+ bool RootSignatureParser::parseParams (ParsedParamState &Params) {
194
+ assert (CurToken.TokKind == TokenKind::pu_l_paren &&
195
+ " Expects to only be invoked starting at given token" );
196
+
197
+ while (tryConsumeExpectedToken (Params.Keywords )) {
198
+ if (Params.checkAndSetSeen (CurToken.TokKind )) {
175
199
getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_repeat_param)
176
200
<< CurToken.TokKind ;
177
201
return true ;
178
202
}
179
- Seen.insert (CurToken.TokKind );
180
203
181
- if (parseParam (Params[CurToken. TokKind ] ))
204
+ if (parseParam (Params))
182
205
return true ;
183
206
184
207
if (!tryConsumeExpectedToken (TokenKind::pu_comma))
185
208
break ;
186
209
}
187
210
188
- bool AllMandatoryDefined = true ;
189
- for (auto Kind : Mandatory) {
190
- bool SeenParam = Seen.contains (Kind);
191
- if (!SeenParam) {
192
- getDiags ().Report (CurToken.TokLoc , diag::err_hlsl_rootsig_missing_param)
193
- << Kind;
194
- }
195
- AllMandatoryDefined &= SeenParam;
196
- }
197
-
198
- return !AllMandatoryDefined;
211
+ return false ;
199
212
}
200
213
201
- bool RootSignatureParser::parseUIntParam (uint32_t * X) {
214
+ bool RootSignatureParser::parseUIntParam (uint32_t & X) {
202
215
assert (CurToken.TokKind == TokenKind::pu_equal &&
203
216
" Expects to only be invoked starting at given keyword" );
204
217
tryConsumeExpectedToken (TokenKind::pu_plus);
@@ -207,7 +220,7 @@ bool RootSignatureParser::parseUIntParam(uint32_t *X) {
207
220
handleUIntLiteral (X);
208
221
}
209
222
210
- bool RootSignatureParser::parseRegister (Register * Register) {
223
+ bool RootSignatureParser::parseRegister (Register & Register) {
211
224
assert ((CurToken.TokKind == TokenKind::bReg ||
212
225
CurToken.TokKind == TokenKind::tReg ||
213
226
CurToken.TokKind == TokenKind::uReg ||
@@ -218,26 +231,26 @@ bool RootSignatureParser::parseRegister(Register *Register) {
218
231
default :
219
232
llvm_unreachable (" Switch for consumed token was not provided" );
220
233
case TokenKind::bReg:
221
- Register-> ViewType = RegisterType::BReg;
234
+ Register. ViewType = RegisterType::BReg;
222
235
break ;
223
236
case TokenKind::tReg:
224
- Register-> ViewType = RegisterType::TReg;
237
+ Register. ViewType = RegisterType::TReg;
225
238
break ;
226
239
case TokenKind::uReg:
227
- Register-> ViewType = RegisterType::UReg;
240
+ Register. ViewType = RegisterType::UReg;
228
241
break ;
229
242
case TokenKind::sReg :
230
- Register-> ViewType = RegisterType::SReg;
243
+ Register. ViewType = RegisterType::SReg;
231
244
break ;
232
245
}
233
246
234
- if (handleUIntLiteral (& Register-> Number ))
247
+ if (handleUIntLiteral (Register. Number ))
235
248
return true ; // propogate NumericLiteralParser error
236
249
237
250
return false ;
238
251
}
239
252
240
- bool RootSignatureParser::handleUIntLiteral (uint32_t * X) {
253
+ bool RootSignatureParser::handleUIntLiteral (uint32_t & X) {
241
254
// Parse the numeric value and do semantic checks on its specification
242
255
clang::NumericLiteralParser Literal (CurToken.NumSpelling , CurToken.TokLoc ,
243
256
PP.getSourceManager (), PP.getLangOpts (),
@@ -256,7 +269,7 @@ bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
256
269
return true ;
257
270
}
258
271
259
- * X = Val.getExtValue ();
272
+ X = Val.getExtValue ();
260
273
return false ;
261
274
}
262
275
0 commit comments