Skip to content

Commit b160219

Browse files
committed
add support for optional parameters
- use numDescriptors as an example
1 parent e7b54bc commit b160219

File tree

4 files changed

+85
-1
lines changed

4 files changed

+85
-1
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,22 @@ class RootSignatureParser {
116116
bool ParseDescriptorTableClause();
117117

118118
// Helper dispatch method
119+
//
120+
// These will switch on the Variant kind to dispatch to the respective Parse
121+
// method and store the parsed value back into Ref.
122+
//
123+
// It is helpful to have a generalized dispatch method so that when we need
124+
// to parse multiple optional parameters in any order, we can invoke this
125+
// method
126+
bool ParseParam(llvm::hlsl::rootsig::ParamType Ref);
127+
128+
// Parse as many optional parameters as possible in any order
129+
bool ParseOptionalParams(
130+
llvm::SmallDenseMap<TokenKind, llvm::hlsl::rootsig::ParamType> &RefMap);
131+
132+
// Common parsing helpers
119133
bool ParseRegister(llvm::hlsl::rootsig::Register *Reg);
134+
bool ParseUInt(uint32_t *X);
120135

121136
/// Invoke the lexer to consume a token and update CurToken with the result
122137
///

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,73 @@ bool RootSignatureParser::ParseDescriptorTableClause() {
320320
if (ParseRegister(&Clause.Register))
321321
return true;
322322

323+
// Parse optional paramaters
324+
llvm::SmallDenseMap<TokenKind, ParamType> RefMap = {
325+
{TokenKind::kw_numDescriptors, &Clause.NumDescriptors},
326+
};
327+
if (ParseOptionalParams({RefMap}))
328+
return true;
329+
323330
if (ConsumeExpectedToken(TokenKind::pu_r_paren))
324331
return true;
325332

326333
Elements.push_back(Clause);
327334
return false;
328335
}
329336

337+
// Helper struct so that we can use the overloaded notation of std::visit
338+
template <class... Ts> struct OverloadedMethods : Ts... {
339+
using Ts::operator()...;
340+
};
341+
template <class... Ts> OverloadedMethods(Ts...) -> OverloadedMethods<Ts...>;
342+
343+
bool RootSignatureParser::ParseParam(ParamType Ref) {
344+
if (ConsumeExpectedToken(TokenKind::pu_equal))
345+
return true;
346+
347+
bool Error;
348+
std::visit(OverloadedMethods{[&](uint32_t *X) { Error = ParseUInt(X); },
349+
}, Ref);
350+
351+
return Error;
352+
}
353+
354+
bool RootSignatureParser::ParseOptionalParams(
355+
llvm::SmallDenseMap<TokenKind, ParamType> &RefMap) {
356+
SmallVector<TokenKind> ParamKeywords;
357+
for (auto RefPair : RefMap)
358+
ParamKeywords.push_back(RefPair.first);
359+
360+
// Keep track of which keywords have been seen to report duplicates
361+
llvm::SmallDenseSet<TokenKind> Seen;
362+
363+
while (!TryConsumeExpectedToken(TokenKind::pu_comma)) {
364+
if (ConsumeExpectedToken(ParamKeywords))
365+
return true;
366+
367+
TokenKind ParamKind = CurToken.Kind;
368+
if (Seen.contains(ParamKind)) {
369+
return true;
370+
}
371+
Seen.insert(ParamKind);
372+
373+
if (ParseParam(RefMap[ParamKind]))
374+
return true;
375+
}
376+
377+
return false;
378+
}
379+
380+
bool RootSignatureParser::ParseUInt(uint32_t *X) {
381+
// Treat a postively signed integer as though it is unsigned to match DXC
382+
TryConsumeExpectedToken(TokenKind::pu_plus);
383+
if (ConsumeExpectedToken(TokenKind::int_literal))
384+
return true;
385+
386+
*X = CurToken.NumLiteral.getInt().getExtValue();
387+
return false;
388+
}
389+
330390
bool RootSignatureParser::ParseRegister(Register *Register) {
331391
switch (CurToken.Kind) {
332392
case TokenKind::bReg:

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
339339
const llvm::StringLiteral Source = R"cc(
340340
DescriptorTable(
341341
CBV(b0),
342-
SRV(t42),
342+
SRV(t42, numDescriptors = +4),
343343
Sampler(s987),
344344
UAV(u987234)
345345
),
@@ -364,6 +364,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
364364
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.ViewType,
365365
RegisterType::BReg);
366366
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number, (uint32_t)0);
367+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
367368

368369
Elem = Elements[1];
369370
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -372,6 +373,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
372373
RegisterType::TReg);
373374
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
374375
(uint32_t)42);
376+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)4);
375377

376378
Elem = Elements[2];
377379
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -380,6 +382,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
380382
RegisterType::SReg);
381383
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
382384
(uint32_t)987);
385+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
383386

384387
Elem = Elements[3];
385388
ASSERT_TRUE(std::holds_alternative<DescriptorTableClause>(Elem));
@@ -388,6 +391,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseDTClausesTest) {
388391
RegisterType::UReg);
389392
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).Register.Number,
390393
(uint32_t)987234);
394+
ASSERT_EQ(std::get<DescriptorTableClause>(Elem).NumDescriptors, (uint32_t)1);
391395

392396
Elem = Elements[4];
393397
ASSERT_TRUE(std::holds_alternative<DescriptorTable>(Elem));

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,16 @@ using ClauseType = llvm::dxil::ResourceClass;
4040
struct DescriptorTableClause {
4141
ClauseType Type;
4242
Register Register;
43+
uint32_t NumDescriptors = 1;
4344
};
4445

4546
// Models RootElement : DescriptorTable | DescriptorTableClause
4647
using RootElement = std::variant<DescriptorTable, DescriptorTableClause>;
4748

49+
// Models a reference to all assignment parameter types that any RootElement
50+
// may have. Things of the form: Keyword = Param
51+
using ParamType = std::variant<uint32_t *>;
52+
4853
} // namespace rootsig
4954
} // namespace hlsl
5055
} // namespace llvm

0 commit comments

Comments
 (0)