Skip to content

Commit fb4e85d

Browse files
author
Finn Plummer
committed
define ParseUInt and ParseRegister to plug into parameters
1 parent 71ae661 commit fb4e85d

File tree

5 files changed

+260
-14
lines changed

5 files changed

+260
-14
lines changed

clang/include/clang/Basic/DiagnosticParseKinds.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,5 +1835,6 @@ def err_hlsl_unexpected_end_of_params
18351835
: Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
18361836
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
18371837
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
1838+
def err_hlsl_number_literal_overflow : Error<"integer literal is too large to be represented as a 32-bit %select{signed |}0 integer type">;
18381839

18391840
} // end of Parser diagnostics

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ class RootSignatureParser {
101101
llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &Params,
102102
llvm::SmallDenseSet<TokenKind> &Mandatory);
103103

104+
/// Parameter parse methods corresponding to a ParamType
105+
bool parseUIntParam(uint32_t *X);
106+
bool parseRegister(llvm::hlsl::rootsig::Register *Reg);
107+
108+
/// Use NumericLiteralParser to convert CurToken.NumSpelling into a unsigned
109+
/// 32-bit integer
110+
bool handleUIntLiteral(uint32_t *X);
111+
104112
/// Invoke the Lexer to consume a token and update CurToken with the result
105113
void consumeNextToken() { CurToken = Lexer.ConsumeToken(); }
106114

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include "clang/Parse/ParseHLSLRootSignature.h"
1010

11+
#include "clang/Lex/LiteralSupport.h"
12+
1113
#include "llvm/Support/raw_ostream.h"
1214

1315
using namespace llvm::hlsl::rootsig;
@@ -87,46 +89,73 @@ bool RootSignatureParser::parseDescriptorTableClause() {
8789
CurToken.Kind == TokenKind::kw_UAV ||
8890
CurToken.Kind == TokenKind::kw_Sampler) &&
8991
"Expects to only be invoked starting at given keyword");
92+
TokenKind ParamKind = CurToken.Kind; // retain for diagnostics
9093

9194
DescriptorTableClause Clause;
92-
switch (CurToken.Kind) {
95+
TokenKind ExpectedRegister;
96+
switch (ParamKind) {
9397
default:
9498
llvm_unreachable("Switch for consumed token was not provided");
9599
case TokenKind::kw_CBV:
96100
Clause.Type = ClauseType::CBuffer;
101+
ExpectedRegister = TokenKind::bReg;
97102
break;
98103
case TokenKind::kw_SRV:
99104
Clause.Type = ClauseType::SRV;
105+
ExpectedRegister = TokenKind::tReg;
100106
break;
101107
case TokenKind::kw_UAV:
102108
Clause.Type = ClauseType::UAV;
109+
ExpectedRegister = TokenKind::uReg;
103110
break;
104111
case TokenKind::kw_Sampler:
105112
Clause.Type = ClauseType::Sampler;
113+
ExpectedRegister = TokenKind::sReg;
106114
break;
107115
}
108116

109117
if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
110-
CurToken.Kind))
118+
ParamKind))
111119
return true;
112120

113-
if (consumeExpectedToken(TokenKind::pu_r_paren, diag::err_expected_after,
114-
CurToken.Kind))
121+
llvm::SmallDenseMap<TokenKind, ParamType> Params = {
122+
{ExpectedRegister, &Clause.Register},
123+
{TokenKind::kw_space, &Clause.Space},
124+
};
125+
llvm::SmallDenseSet<TokenKind> Mandatory = {
126+
ExpectedRegister,
127+
};
128+
129+
if (parseParams(Params, Mandatory))
115130
return true;
116131

132+
if (!tryConsumeExpectedToken(TokenKind::pu_r_paren)) {
133+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_unexpected_end_of_params)
134+
<< /*expected=*/TokenKind::pu_r_paren
135+
<< /*param of=*/ParamKind;
136+
return true;
137+
}
138+
117139
Elements.push_back(Clause);
118140
return false;
119141
}
120142

121-
// Helper struct so that we can use the overloaded notation of std::visit
143+
// Helper struct defined to use the overloaded notation of std::visit.
122144
template <class... Ts> struct ParseMethods : Ts... { using Ts::operator()...; };
123145
template <class... Ts> ParseMethods(Ts...) -> ParseMethods<Ts...>;
124146

125147
bool RootSignatureParser::parseParam(ParamType Ref) {
126-
bool Error = true;
127-
std::visit(ParseMethods{}, Ref);
128-
129-
return Error;
148+
return std::visit(
149+
ParseMethods{
150+
[this](Register *X) -> bool { return parseRegister(X); },
151+
[this](uint32_t *X) -> bool {
152+
return consumeExpectedToken(TokenKind::pu_equal,
153+
diag::err_expected_after,
154+
CurToken.Kind) ||
155+
parseUIntParam(X);
156+
},
157+
},
158+
Ref);
130159
}
131160

132161
bool RootSignatureParser::parseParams(
@@ -169,6 +198,67 @@ bool RootSignatureParser::parseParams(
169198
return !AllMandatoryDefined;
170199
}
171200

201+
bool RootSignatureParser::parseUIntParam(uint32_t *X) {
202+
assert(CurToken.Kind == TokenKind::pu_equal &&
203+
"Expects to only be invoked starting at given keyword");
204+
tryConsumeExpectedToken(TokenKind::pu_plus);
205+
return consumeExpectedToken(TokenKind::int_literal, diag::err_expected_after,
206+
CurToken.Kind) ||
207+
handleUIntLiteral(X);
208+
}
209+
210+
bool RootSignatureParser::parseRegister(Register *Register) {
211+
assert(
212+
(CurToken.Kind == TokenKind::bReg || CurToken.Kind == TokenKind::tReg ||
213+
CurToken.Kind == TokenKind::uReg || CurToken.Kind == TokenKind::sReg) &&
214+
"Expects to only be invoked starting at given keyword");
215+
216+
switch (CurToken.Kind) {
217+
case TokenKind::bReg:
218+
Register->ViewType = RegisterType::BReg;
219+
break;
220+
case TokenKind::tReg:
221+
Register->ViewType = RegisterType::TReg;
222+
break;
223+
case TokenKind::uReg:
224+
Register->ViewType = RegisterType::UReg;
225+
break;
226+
case TokenKind::sReg:
227+
Register->ViewType = RegisterType::SReg;
228+
break;
229+
default:
230+
break; // Unreachable given Try + assert pattern
231+
}
232+
233+
if (handleUIntLiteral(&Register->Number))
234+
return true; // propogate NumericLiteralParser error
235+
236+
return false;
237+
}
238+
239+
bool RootSignatureParser::handleUIntLiteral(uint32_t *X) {
240+
// Parse the numeric value and do semantic checks on its specification
241+
clang::NumericLiteralParser Literal(CurToken.NumSpelling, CurToken.TokLoc,
242+
PP.getSourceManager(), PP.getLangOpts(),
243+
PP.getTargetInfo(), PP.getDiagnostics());
244+
if (Literal.hadError)
245+
return true; // Error has already been reported so just return
246+
247+
assert(Literal.isIntegerLiteral() && "IsNumberChar will only support digits");
248+
249+
llvm::APSInt Val = llvm::APSInt(32, false);
250+
if (Literal.GetIntegerValue(Val)) {
251+
// Report that the value has overflowed
252+
PP.getDiagnostics().Report(CurToken.TokLoc,
253+
diag::err_hlsl_number_literal_overflow)
254+
<< 0 << CurToken.NumSpelling;
255+
return true;
256+
}
257+
258+
*X = Val.getExtValue();
259+
return false;
260+
}
261+
172262
bool RootSignatureParser::peekExpectedToken(TokenKind Expected) {
173263
return peekExpectedToken(ArrayRef{Expected});
174264
}

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 142 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyTest) {
130130
TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
131131
const llvm::StringLiteral Source = R"cc(
132132
DescriptorTable(
133-
CBV(),
134-
SRV(),
135-
Sampler(),
136-
UAV()
133+
CBV(b0),
134+
SRV(space = 3, t42),
135+
Sampler(s987, space = +2),
136+
UAV(u4294967294)
137137
),
138138
DescriptorTable()
139139
)cc";
@@ -155,18 +155,34 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
155155
RootElement Elem = Elements[0];
156156
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
157157
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::CBuffer);
158+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
159+
RegisterType::BReg);
160+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 0u);
161+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
158162

159163
Elem = Elements[1];
160164
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
161165
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::SRV);
166+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
167+
RegisterType::TReg);
168+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 42u);
169+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 3u);
162170

163171
Elem = Elements[2];
164172
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
165173
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::Sampler);
174+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
175+
RegisterType::SReg);
176+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 987u);
177+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 2u);
166178

167179
Elem = Elements[3];
168180
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
169181
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Type, ClauseType::UAV);
182+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
183+
RegisterType::UReg);
184+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, 4294967294u);
185+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Space, 0u);
170186

171187
Elem = Elements[4];
172188
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
@@ -176,6 +192,32 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
176192
Elem = Elements[5];
177193
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));
178194
ASSERT_EQ(std::get<DescriptorTable>(Elem).NumClauses, 0u);
195+
196+
ASSERT_TRUE(Consumer->isSatisfied());
197+
}
198+
199+
TEST_F(ParseHLSLRootSignatureTest, ValidTrailingCommaTest) {
200+
// This test will checks we can handling trailing commas ','
201+
const llvm::StringLiteral Source = R"cc(
202+
DescriptorTable(
203+
CBV(b0, ),
204+
SRV(t42),
205+
)
206+
)cc";
207+
208+
TrivialModuleLoader ModLoader;
209+
auto PP = createPP(Source, ModLoader);
210+
auto TokLoc = SourceLocation();
211+
212+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
213+
SmallVector<RootElement> Elements;
214+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
215+
216+
// Test no diagnostics produced
217+
Consumer->setNoDiag();
218+
219+
ASSERT_FALSE(Parser.parse());
220+
179221
ASSERT_TRUE(Consumer->isSatisfied());
180222
}
181223

@@ -237,6 +279,102 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidParseUnexpectedEndOfStreamTest) {
237279

238280
// Test correct diagnostic produced - end of stream
239281
Consumer->setExpected(diag::err_expected_after);
282+
283+
ASSERT_TRUE(Parser.parse());
284+
285+
ASSERT_TRUE(Consumer->isSatisfied());
286+
}
287+
288+
TEST_F(ParseHLSLRootSignatureTest, InvalidMissingParameterTest) {
289+
// This test will check that the parsing fails due a mandatory
290+
// parameter (register) not being specified
291+
const llvm::StringLiteral Source = R"cc(
292+
DescriptorTable(
293+
CBV()
294+
)
295+
)cc";
296+
297+
TrivialModuleLoader ModLoader;
298+
auto PP = createPP(Source, ModLoader);
299+
auto TokLoc = SourceLocation();
300+
301+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
302+
SmallVector<RootElement> Elements;
303+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
304+
305+
// Test correct diagnostic produced
306+
Consumer->setExpected(diag::err_hlsl_rootsig_missing_param);
307+
ASSERT_TRUE(Parser.parse());
308+
309+
ASSERT_TRUE(Consumer->isSatisfied());
310+
}
311+
312+
TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedMandatoryParameterTest) {
313+
// This test will check that the parsing fails due the same mandatory
314+
// parameter being specified multiple times
315+
const llvm::StringLiteral Source = R"cc(
316+
DescriptorTable(
317+
CBV(b32, b84)
318+
)
319+
)cc";
320+
321+
TrivialModuleLoader ModLoader;
322+
auto PP = createPP(Source, ModLoader);
323+
auto TokLoc = SourceLocation();
324+
325+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
326+
SmallVector<RootElement> Elements;
327+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
328+
329+
// Test correct diagnostic produced
330+
Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
331+
ASSERT_TRUE(Parser.parse());
332+
333+
ASSERT_TRUE(Consumer->isSatisfied());
334+
}
335+
336+
TEST_F(ParseHLSLRootSignatureTest, InvalidRepeatedOptionalParameterTest) {
337+
// This test will check that the parsing fails due the same optional
338+
// parameter being specified multiple times
339+
const llvm::StringLiteral Source = R"cc(
340+
DescriptorTable(
341+
CBV(space = 2, space = 0)
342+
)
343+
)cc";
344+
345+
TrivialModuleLoader ModLoader;
346+
auto PP = createPP(Source, ModLoader);
347+
auto TokLoc = SourceLocation();
348+
349+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
350+
SmallVector<RootElement> Elements;
351+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
352+
353+
// Test correct diagnostic produced
354+
Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
355+
ASSERT_TRUE(Parser.parse());
356+
357+
ASSERT_TRUE(Consumer->isSatisfied());
358+
}
359+
360+
TEST_F(ParseHLSLRootSignatureTest, InvalidLexOverflowedNumberTest) {
361+
// This test will check that the lexing fails due to an integer overflow
362+
const llvm::StringLiteral Source = R"cc(
363+
DescriptorTable(
364+
CBV(b4294967296)
365+
)
366+
)cc";
367+
368+
TrivialModuleLoader ModLoader;
369+
auto PP = createPP(Source, ModLoader);
370+
auto TokLoc = SourceLocation();
371+
372+
hlsl::RootSignatureLexer Lexer(Source, TokLoc);
373+
SmallVector<RootElement> Elements;
374+
hlsl::RootSignatureParser Parser(Elements, Lexer, *PP);
375+
376+
// Test correct diagnostic produced
377+
Consumer->setExpected(diag::err_hlsl_number_literal_overflow);
240378
ASSERT_TRUE(Parser.parse());
241379

242380
ASSERT_TRUE(Consumer->isSatisfied());

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ namespace rootsig {
2323

2424
// Definitions of the in-memory data layout structures
2525

26+
// Models the different registers: bReg | tReg | uReg | sReg
27+
enum class RegisterType { BReg, TReg, UReg, SReg };
28+
struct Register {
29+
RegisterType ViewType;
30+
uint32_t Number;
31+
};
32+
2633
// Models the end of a descriptor table and stores its visibility
2734
struct DescriptorTable {
2835
uint32_t NumClauses = 0; // The number of clauses in the table
@@ -32,6 +39,8 @@ struct DescriptorTable {
3239
using ClauseType = llvm::dxil::ResourceClass;
3340
struct DescriptorTableClause {
3441
ClauseType Type;
42+
Register Register;
43+
uint32_t Space = 0;
3544
};
3645

3746
// Models RootElement : DescriptorTable | DescriptorTableClause
@@ -41,7 +50,7 @@ using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
4150
// RootElement. Each variant of ParamType is expected to have a Parse method
4251
// defined that will be dispatched on when we are attempting to parse a
4352
// parameter
44-
using ParamType = std::variant<std::monostate>;
53+
using ParamType = std::variant<uint32_t *, Register *>;
4554

4655
} // namespace rootsig
4756
} // namespace hlsl

0 commit comments

Comments
 (0)