Skip to content

Commit 180c33c

Browse files
author
Finn Plummer
committed
define ParseUInt and ParseRegister to plug into parameters
1 parent 9ff87eb commit 180c33c

File tree

5 files changed

+257
-14
lines changed

5 files changed

+257
-14
lines changed

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,5 +1835,6 @@ def err_hlsl_unexpected_end_of_params
18351835
: Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
18361836
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
18371837
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
1838+
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
18381839

18391840
} // end of Parser diagnostics

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ class RootSignatureParser {
101101
llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
102102
llvm::SmallDenseSet<TokenKind> &Mandatory);
103103

104+
/// Parameter parse methods corresponding to a ParamType
105+
bool parseUIntParam(uint32_t *X);
106+
bool parseRegister(llvm::hlsl::rootsig::Register *Reg);
107+
108+
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
109+
/// 32-bit integer
110+
bool handleUIntLiteral(uint32_t *X);
111+
104112
/// Invoke the Lexer to consume a token and update CurToken with the result
105113
void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
106114

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 96 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "clang/Parse/ParseHLSLRootSignature.h"
1010

11+
#include "clang/Lex/LiteralSupport.h"
12+
1113
#include "llvm/Support/raw_ostream.h"
1214

1315
using namespace llvm::hlsl::rootsig;
@@ -89,46 +91,70 @@ bool RootSignatureParser::parseDescriptorTableClause() {
8991
CurToken.TokKind == TokenKind::kw_UAV ||
9092
CurToken.TokKind == TokenKind::kw_Sampler) &&
9193
"Expects to only be invoked starting at given keyword");
94+
TokenKind ParamKind = CurToken.TokKind; // retain for diagnostics
9295

9396
DescriptorTableClause Clause;
94-
switch (CurToken.TokKind) {
97+
TokenKind ExpectedRegister;
98+
switch (ParamKind) {
9599
default:
96100
llvm_unreachable("Switch for consumed token was not provided");
97101
case TokenKind::kw_CBV:
98102
Clause.Type = ClauseType::CBuffer;
103+
ExpectedRegister = TokenKind::bReg;
99104
break;
100105
case TokenKind::kw_SRV:
101106
Clause.Type = ClauseType::SRV;
107+
ExpectedRegister = TokenKind::tReg;
102108
break;
103109
case TokenKind::kw_UAV:
104110
Clause.Type = ClauseType::UAV;
111+
ExpectedRegister = TokenKind::uReg;
105112
break;
106113
case TokenKind::kw_Sampler:
107114
Clause.Type = ClauseType::Sampler;
115+
ExpectedRegister = TokenKind::sReg;
108116
break;
109117
}
110118

111119
if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
112-
CurToken.TokKind))
120+
ParamKind))
113121
return true;
114122

115-
if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
116-
CurToken.TokKind))
123+
llvm::SmallDenseMap<TokenKind, ParamType> Params = {
124+
{ExpectedRegister, &Clause.Register},
125+
{TokenKind::kw_space, &Clause.Space},
126+
};
127+
llvm::SmallDenseSet<TokenKind> Mandatory = {
128+
ExpectedRegister,
129+
};
130+
131+
if (parseParams(Params, Mandatory))
132+
return true;
133+
134+
if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_hlsl_unexpected_end_of_params,
135+
ParamKind))
117136
return true;
118137

119138
Elements.push_back(Clause);
120139
return false;
121140
}
122141

123-
// Helper struct so that we can use the overloaded notation of std::visit
142+
// Helper struct defined to use the overloaded notation of std::visit.
124143
template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
125144
template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
126145

127146
bool RootSignatureParser::parseParam(ParamType Ref) {
128-
bool Error = true;
129-
std::visit(ParseMethods{}, Ref);
130-
131-
return Error;
147+
return std::visit(
148+
ParseMethods{
149+
[this](Register *X) -> bool { return parseRegister(X); },
150+
[this](uint32_t *X) -> bool {
151+
return consumeExpectedToken(TokenKind::pu_equal,
152+
diag::err_expected_after,
153+
CurToken.Kind) ||
154+
parseUIntParam(X);
155+
},
156+
},
157+
Ref);
132158
}
133159

134160
bool RootSignatureParser::parseParams(
@@ -171,6 +197,67 @@ bool RootSignatureParser::parseParams(
171197
return !AllMandatoryDefined;
172198
}
173199

200+
bool RootSignatureParser::parseUIntParam(uint32_t *X) {
201+
assert(CurToken.Kind == TokenKind::pu_equal &&
202+
"Expects to only be invoked starting at given keyword");
203+
tryConsumeExpectedToken(TokenKind::pu_plus);
204+
return consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
205+
CurToken.Kind) ||
206+
handleUIntLiteral(X);
207+
}
208+
209+
bool RootSignatureParser::parseRegister(Register *Register) {
210+
assert(
211+
(CurToken.Kind == TokenKind::bReg || CurToken.Kind == TokenKind::tReg ||
212+
CurToken.Kind == TokenKind::uReg || CurToken.Kind == TokenKind::sReg) &&
213+
"Expects to only be invoked starting at given keyword");
214+
215+
switch (CurToken.Kind) {
216+
case TokenKind::bReg:
217+
Register->ViewType = RegisterType::BReg;
218+
break;
219+
case TokenKind::tReg:
220+
Register->ViewType = RegisterType::TReg;
221+
break;
222+
case TokenKind::uReg:
223+
Register->ViewType = RegisterType::UReg;
224+
break;
225+
case TokenKind::sReg:
226+
Register->ViewType = RegisterType::SReg;
227+
break;
228+
default:
229+
break; // Unreachable given Try + assert pattern
230+
}
231+
232+
if (handleUIntLiteral(&Register->Number))
233+
return true; // propogate NumericLiteralParser error
234+
235+
return false;
236+
}
237+
238+
bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
239+
// Parse the numeric value and do semantic checks on its specification
240+
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
241+
PP.getSourceManager(), PP.getLangOpts(),
242+
PP.getTargetInfo(), PP.getDiagnostics());
243+
if (Literal.hadError)
244+
return true; // Error has already been reported so just return
245+
246+
assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
247+
248+
llvm::APSInt Val = llvm::APSInt(32, false);
249+
if (Literal.GetIntegerValue(Val)) {
250+
// Report that the value has overflowed
251+
PP.getDiagnostics().Report(CurToken.TokLoc,
252+
diag::err_hlsl_number_literal_overflow)
253+
<< 0 << CurToken.NumSpelling;
254+
return true;
255+
}
256+
257+
*X = Val.getExtValue();
258+
return false;
259+
}
260+
174261
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
175262
return peekExpectedToken(ArrayRef{Expected});
176263
}

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 142 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
130130
TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
131131
const llvm::StringLiteral Source = R"cc(
132132
DescriptorTable(
133-
CBV(),
134-
SRV(),
135-
Sampler(),
136-
UAV()
133+
CBV(b0),
134+
SRV(space = 3, t42),
135+
Sampler(s987, space = +2),
136+
UAV(u4294967294)
137137
),
138138
DescriptorTable()
139139
)cc";
@@ -155,18 +155,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
155155
RootElement Elem = Elements[0];
156156
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
157157
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
158+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
159+
RegisterType::BReg);
160+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 0u);
161+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
158162

159163
Elem = Elements[1];
160164
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
161165
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
166+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
167+
RegisterType::TReg);
168+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 42u);
169+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
162170

163171
Elem = Elements[2];
164172
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
165173
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
174+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
175+
RegisterType::SReg);
176+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 987u);
177+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
166178

167179
Elem = Elements[3];
168180
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
169181
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
182+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
183+
RegisterType::UReg);
184+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 4294967294u);
185+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
170186

171187
Elem = Elements[4];
172188
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -176,6 +192,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
176192
Elem = Elements[5];
177193
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
178194
ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u);
195+
196+
ASSERT_TRUE(Consumer->isSatisfied());
197+
}
198+
199+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
200+
// This test will checks we can handling trailing commas ','
201+
const llvm::StringLiteral Source = R"cc(
202+
DescriptorTable(
203+
CBV(b0, ),
204+
SRV(t42),
205+
)
206+
)cc";
207+
208+
TrivialModuleLoader ModLoader;
209+
auto PP = createPP(Source, ModLoader);
210+
auto TokLoc = SourceLocation();
211+
212+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
213+
SmallVector<RootElement> Elements;
214+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
215+
216+
// Test no diagnostics produced
217+
Consumer->setNoDiag();
218+
219+
ASSERT_FALSE(Parser.parse());
220+
179221
ASSERT_TRUE(Consumer->isSatisfied());
180222
}
181223

@@ -237,6 +279,102 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) {
237279

238280
// Test correct diagnostic produced - end of stream
239281
Consumer->setExpected(diag::err_expected_after);
282+
283+
ASSERT_TRUE(Parser.parse());
284+
285+
ASSERT_TRUE(Consumer->isSatisfied());
286+
}
287+
288+
TEST_F(ParseHLSLRootSignatureTest, InvalidMissingParameterTest) {
289+
// This test will check that the parsing fails due a mandatory
290+
// parameter (register) not being specified
291+
const llvm::StringLiteral Source = R"cc(
292+
DescriptorTable(
293+
CBV()
294+
)
295+
)cc";
296+
297+
TrivialModuleLoader ModLoader;
298+
auto PP = createPP(Source, ModLoader);
299+
auto TokLoc = SourceLocation();
300+
301+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
302+
SmallVector<RootElement> Elements;
303+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
304+
305+
// Test correct diagnostic produced
306+
Consumer->setExpected(diag::err_hlsl_rootsig_missing_param);
307+
ASSERT_TRUE(Parser.parse());
308+
309+
ASSERT_TRUE(Consumer->isSatisfied());
310+
}
311+
312+
TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryParameterTest) {
313+
// This test will check that the parsing fails due the same mandatory
314+
// parameter being specified multiple times
315+
const llvm::StringLiteral Source = R"cc(
316+
DescriptorTable(
317+
CBV(b32, b84)
318+
)
319+
)cc";
320+
321+
TrivialModuleLoader ModLoader;
322+
auto PP = createPP(Source, ModLoader);
323+
auto TokLoc = SourceLocation();
324+
325+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
326+
SmallVector<RootElement> Elements;
327+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
328+
329+
// Test correct diagnostic produced
330+
Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
331+
ASSERT_TRUE(Parser.parse());
332+
333+
ASSERT_TRUE(Consumer->isSatisfied());
334+
}
335+
336+
TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalParameterTest) {
337+
// This test will check that the parsing fails due the same optional
338+
// parameter being specified multiple times
339+
const llvm::StringLiteral Source = R"cc(
340+
DescriptorTable(
341+
CBV(space = 2, space = 0)
342+
)
343+
)cc";
344+
345+
TrivialModuleLoader ModLoader;
346+
auto PP = createPP(Source, ModLoader);
347+
auto TokLoc = SourceLocation();
348+
349+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
350+
SmallVector<RootElement> Elements;
351+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
352+
353+
// Test correct diagnostic produced
354+
Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
355+
ASSERT_TRUE(Parser.parse());
356+
357+
ASSERT_TRUE(Consumer->isSatisfied());
358+
}
359+
360+
TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
361+
// This test will check that the lexing fails due to an integer overflow
362+
const llvm::StringLiteral Source = R"cc(
363+
DescriptorTable(
364+
CBV(b4294967296)
365+
)
366+
)cc";
367+
368+
TrivialModuleLoader ModLoader;
369+
auto PP = createPP(Source, ModLoader);
370+
auto TokLoc = SourceLocation();
371+
372+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
373+
SmallVector<RootElement> Elements;
374+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
375+
376+
// Test correct diagnostic produced
377+
Consumer->setExpected(diag::err_hlsl_number_literal_overflow);
240378
ASSERT_TRUE(Parser.parse());
241379

242380
ASSERT_TRUE(Consumer->isSatisfied());

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ namespace rootsig {
2323

2424
// Definitions of the in-memory data layout structures
2525

26+
// Models the different registers: bReg | tReg | uReg | sReg
27+
enum class RegisterType { BReg, TReg, UReg, SReg };
28+
struct Register {
29+
RegisterType ViewType;
30+
uint32_t Number;
31+
};
32+
2633
// Models the end of a descriptor table and stores its visibility
2734
struct DescriptorTable {
2835
uint32_t NumClauses = 0; // The number of clauses in the table
@@ -32,6 +39,8 @@ struct DescriptorTable {
3239
using ClauseType = llvm::dxil::ResourceClass;
3340
struct DescriptorTableClause {
3441
ClauseType Type;
42+
Register Register;
43+
uint32_t Space = 0;
3544
};
3645

3746
// Models RootElement : DescriptorTable | DescriptorTableClause
@@ -41,7 +50,7 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
4150
// RootElement. Each variant of ParamType is expected to have a Parse method
4251
// defined that will be dispatched on when we are attempting to parse a
4352
// parameter
44-
using ParamType = std::variant<std::monostate>;
53+
using ParamType = std::variant<uint32_t *, Register *>;
4554

4655
} // namespace rootsig
4756
} // namespace hlsl

0 commit comments

Comments
 (0)