Skip to content

Commit 425e8dc

Browse files
committed
[HLSL][RootSignature] Implement parsing of RootParamters
- Implement the ParseRootParameter methods in ParseHLSLRootSignature - Define the in-memory represenation of the RootFlag and adds it to the RootElement structure - Add testing of valid inputs to ParseHLSLRootSignatureTest.cpp
1 parent bee90e6 commit 425e8dc

File tree

4 files changed

+331
-0
lines changed

4 files changed

+331
-0
lines changed

clang/include/clang/Sema/ParseHLSLRootSignature.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
1414
#define LLVM_CLANG_SEMA_PARSEHLSLROOTSIGNATURE_H
1515

16+
#include "llvm/ADT/APInt.h"
1617
#include "llvm/ADT/SmallVector.h"
1718
#include "llvm/ADT/StringRef.h"
1819
#include "llvm/ADT/StringSwitch.h"
@@ -43,11 +44,22 @@ class Parser {
4344
bool ParseRootElement();
4445

4546
bool ParseRootFlags();
47+
bool ParseRootParameter();
48+
49+
// Helper methods
50+
bool ParseAssign();
51+
bool ParseComma();
52+
bool ParseOptComma();
53+
bool ParseRegister(Register &);
54+
bool ParseUnsignedInt(uint32_t &Number);
55+
4656
// Enum methods
4757
template <typename EnumType>
4858
bool ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
4959
EnumType &Enum);
60+
bool ParseRootDescriptorFlag(RootDescriptorFlags &Flag);
5061
bool ParseRootFlag(RootFlags &Flag);
62+
bool ParseVisibility(ShaderVisibility &Visibility);
5163

5264
// Internal state used when parsing
5365
StringRef Buffer;

clang/lib/Sema/ParseHLSLRootSignature.cpp

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,148 @@ bool Parser::ParseRootFlags() {
3838
return false;
3939
}
4040

41+
bool Parser::ParseRootParameter() {
42+
RootParameter Parameter;
43+
Parameter.Type = llvm::StringSwitch<RootType>(Token)
44+
.Case("CBV", RootType::CBV)
45+
.Case("SRV", RootType::SRV)
46+
.Case("UAV", RootType::UAV)
47+
.Case("RootConstants", RootType::Constants);
48+
// Will never reach here as Token was just verified in dispatch
49+
50+
// Remove any whitespace
51+
Buffer = Buffer.drop_while(isspace);
52+
53+
// Retreive mandatory num32BitConstant arg for RootConstants
54+
if (Parameter.Type == RootType::Constants) {
55+
if (!Buffer.consume_front("num32BitConstants"))
56+
return ReportError();
57+
58+
if (ParseAssign())
59+
return ReportError();
60+
61+
if (ParseUnsignedInt(Parameter.Num32BitConstants))
62+
return ReportError();
63+
64+
if (ParseOptComma())
65+
return ReportError();
66+
}
67+
68+
// Retrieve mandatory register
69+
if (ParseRegister(Parameter.Register))
70+
return true;
71+
72+
if (ParseOptComma())
73+
return ReportError();
74+
75+
// Parse common optional space arg
76+
if (Buffer.consume_front("space")) {
77+
if (ParseAssign())
78+
return ReportError();
79+
80+
if (ParseUnsignedInt(Parameter.Space))
81+
return ReportError();
82+
83+
if (ParseOptComma())
84+
return ReportError();
85+
}
86+
87+
// Parse common optional visibility arg
88+
if (Buffer.consume_front("visibility")) {
89+
if (ParseAssign())
90+
return ReportError();
91+
92+
if (ParseVisibility(Parameter.Visibility))
93+
return ReportError();
94+
95+
if (ParseOptComma())
96+
return ReportError();
97+
}
98+
99+
// Retreive optional flags arg for non-RootConstants
100+
if (Parameter.Type != RootType::Constants && Buffer.consume_front("flags")) {
101+
if (ParseAssign())
102+
return ReportError();
103+
104+
if (ParseRootDescriptorFlag(Parameter.Flags))
105+
return ReportError();
106+
107+
// Remove trailing whitespace
108+
Buffer = Buffer.drop_while(isspace);
109+
}
110+
111+
// Create and push the root element on the parsed elements
112+
Elements->push_back(RootElement(Parameter));
113+
return false;
114+
}
115+
116+
// Helper Parser methods
117+
118+
// Parses " = " with varying whitespace
119+
bool Parser::ParseAssign() {
120+
Buffer = Buffer.drop_while(isspace);
121+
if (!Buffer.starts_with('='))
122+
return true;
123+
Buffer = Buffer.drop_front();
124+
Buffer = Buffer.drop_while(isspace);
125+
return false;
126+
}
127+
128+
// Parses ", " with varying whitespace
129+
bool Parser::ParseComma() {
130+
if (!Buffer.starts_with(','))
131+
return true;
132+
Buffer = Buffer.drop_front();
133+
Buffer = Buffer.drop_while(isspace);
134+
return false;
135+
}
136+
137+
// Parses ", " if possible. When successful we expect another parameter, and
138+
// return no error, otherwise we expect that we should be at the end of the
139+
// root element and return an error if this isn't the case
140+
bool Parser::ParseOptComma() {
141+
if (!ParseComma())
142+
return false;
143+
Buffer = Buffer.drop_while(isspace);
144+
return !Buffer.starts_with(')');
145+
}
146+
147+
bool Parser::ParseRegister(Register &Register) {
148+
// Parse expected register type ('b', 't', 'u', 's')
149+
if (Buffer.empty())
150+
return ReportError();
151+
152+
// Get type character
153+
Token = Buffer.take_front();
154+
Buffer = Buffer.drop_front();
155+
156+
auto MaybeType = llvm::StringSwitch<std::optional<RegisterType>>(Token)
157+
.Case("b", RegisterType::BReg)
158+
.Case("t", RegisterType::TReg)
159+
.Case("u", RegisterType::UReg)
160+
.Case("s", RegisterType::SReg)
161+
.Default(std::nullopt);
162+
if (!MaybeType)
163+
return ReportError();
164+
Register.ViewType = *MaybeType;
165+
166+
if (ParseUnsignedInt(Register.Number))
167+
return ReportError();
168+
169+
return false;
170+
}
171+
172+
// Parses "[0-9+]" as an unsigned int
173+
bool Parser::ParseUnsignedInt(uint32_t &Number) {
174+
StringRef NumString = Buffer.take_while(isdigit);
175+
APInt X = APInt(32, 0);
176+
if (NumString.getAsInteger(/*radix=*/10, X))
177+
return true;
178+
Number = X.getZExtValue();
179+
Buffer = Buffer.drop_front(NumString.size());
180+
return false;
181+
}
182+
41183
template <typename EnumType>
42184
bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
43185
EnumType &Enum) {
@@ -57,6 +199,18 @@ bool Parser::ParseEnum(SmallVector<std::pair<StringLiteral, EnumType>> Mapping,
57199
return false;
58200
}
59201

202+
bool Parser::ParseRootDescriptorFlag(RootDescriptorFlags &Flag) {
203+
SmallVector<std::pair<StringLiteral, RootDescriptorFlags>> Mapping = {
204+
{"0", RootDescriptorFlags::None},
205+
{"DATA_VOLATILE", RootDescriptorFlags::DataVolatile},
206+
{"DATA_STATIC_WHILE_SET_AT_EXECUTE",
207+
RootDescriptorFlags::DataStaticWhileSetAtExecute},
208+
{"DATA_STATIC", RootDescriptorFlags::DataStatic},
209+
};
210+
211+
return ParseEnum<RootDescriptorFlags>(Mapping, Flag);
212+
}
213+
60214
bool Parser::ParseRootFlag(RootFlags &Flag) {
61215
SmallVector<std::pair<StringLiteral, RootFlags>> Mapping = {
62216
{"0", RootFlags::None},
@@ -83,16 +237,36 @@ bool Parser::ParseRootFlag(RootFlags &Flag) {
83237
return ParseEnum<RootFlags>(Mapping, Flag);
84238
}
85239

240+
bool Parser::ParseVisibility(ShaderVisibility &Visibility) {
241+
SmallVector<std::pair<StringLiteral, ShaderVisibility>> Mapping = {
242+
{"SHADER_VISIBILITY_ALL", ShaderVisibility::All},
243+
{"SHADER_VISIBILITY_VERTEX", ShaderVisibility::Vertex},
244+
{"SHADER_VISIBILITY_HULL", ShaderVisibility::Hull},
245+
{"SHADER_VISIBILITY_DOMAIN", ShaderVisibility::Domain},
246+
{"SHADER_VISIBILITY_GEOMETRY", ShaderVisibility::Geometry},
247+
{"SHADER_VISIBILITY_PIXEL", ShaderVisibility::Pixel},
248+
{"SHADER_VISIBILITY_AMPLIFICATION", ShaderVisibility::Amplification},
249+
{"SHADER_VISIBILITY_MESH", ShaderVisibility::Mesh},
250+
};
251+
252+
return ParseEnum<ShaderVisibility>(Mapping, Visibility);
253+
}
254+
86255
bool Parser::ParseRootElement() {
87256
// Define different ParserMethods to use StringSwitch for dispatch
88257
enum class ParserMethod {
89258
ReportError,
90259
ParseRootFlags,
260+
ParseRootParameter,
91261
};
92262

93263
// Retreive which method should be used
94264
auto Method = llvm::StringSwitch<ParserMethod>(Token)
95265
.Case("RootFlags", ParserMethod::ParseRootFlags)
266+
.Case("RootConstants", ParserMethod::ParseRootParameter)
267+
.Case("CBV", ParserMethod::ParseRootParameter)
268+
.Case("SRV", ParserMethod::ParseRootParameter)
269+
.Case("UAV", ParserMethod::ParseRootParameter)
96270
.Default(ParserMethod::ReportError);
97271

98272
// Dispatch on the correct method
@@ -101,6 +275,8 @@ bool Parser::ParseRootElement() {
101275
return ReportError();
102276
case ParserMethod::ParseRootFlags:
103277
return ParseRootFlags();
278+
case ParserMethod::ParseRootParameter:
279+
return ParseRootParameter();
104280
}
105281
}
106282

clang/unittests/Sema/ParseHLSLRootSignatureTest.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,102 @@ TEST(ParseHLSLRootSignature, ValidRootFlags) {
5555
ASSERT_EQ(RootFlags::ValidFlags, RootElements[0].Flags);
5656
}
5757

58+
TEST(ParseHLSLRootSignature, MandatoryRootConstant) {
59+
llvm::StringRef RootFlagString = "RootConstants(num32BitConstants = 4, b42)";
60+
llvm::SmallVector<RootElement> RootElements;
61+
Parser Parser(RootFlagString, &RootElements);
62+
ASSERT_FALSE(Parser.Parse());
63+
ASSERT_EQ(RootElements.size(), (unsigned long)1);
64+
65+
RootParameter Parameter = RootElements[0].Parameter;
66+
ASSERT_EQ(RootType::Constants, Parameter.Type);
67+
ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
68+
ASSERT_EQ((uint32_t)42, Parameter.Register.Number);
69+
ASSERT_EQ((uint32_t)4, Parameter.Num32BitConstants);
70+
ASSERT_EQ((uint32_t)0, Parameter.Space);
71+
ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
72+
}
73+
74+
TEST(ParseHLSLRootSignature, OptionalRootConstant) {
75+
llvm::StringRef RootFlagString =
76+
"RootConstants(num32BitConstants = 4, b42, space = 4, visibility = "
77+
"SHADER_VISIBILITY_DOMAIN)";
78+
llvm::SmallVector<RootElement> RootElements;
79+
Parser Parser(RootFlagString, &RootElements);
80+
ASSERT_FALSE(Parser.Parse());
81+
ASSERT_EQ(RootElements.size(), (unsigned long)1);
82+
83+
RootParameter Parameter = RootElements[0].Parameter;
84+
ASSERT_EQ(RootType::Constants, Parameter.Type);
85+
ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
86+
ASSERT_EQ((uint32_t)42, Parameter.Register.Number);
87+
ASSERT_EQ((uint32_t)4, Parameter.Num32BitConstants);
88+
ASSERT_EQ((uint32_t)4, Parameter.Space);
89+
ASSERT_EQ(ShaderVisibility::Domain, Parameter.Visibility);
90+
}
91+
92+
TEST(ParseHLSLRootSignature, DefaultRootCBV) {
93+
llvm::StringRef ViewsString = "CBV(b0)";
94+
llvm::SmallVector<RootElement> RootElements;
95+
Parser Parser(ViewsString, &RootElements);
96+
ASSERT_FALSE(Parser.Parse());
97+
ASSERT_EQ(RootElements.size(), (unsigned long)1);
98+
99+
RootParameter Parameter = RootElements[0].Parameter;
100+
ASSERT_EQ(RootType::CBV, Parameter.Type);
101+
ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
102+
ASSERT_EQ((uint32_t)0, Parameter.Register.Number);
103+
ASSERT_EQ(RootDescriptorFlags::None, Parameter.Flags);
104+
ASSERT_EQ((uint32_t)0, Parameter.Space);
105+
ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
106+
}
107+
108+
TEST(ParseHLSLRootSignature, SampleRootCBV) {
109+
llvm::StringRef ViewsString = "CBV(b982374, space = 1, flags = DATA_STATIC)";
110+
llvm::SmallVector<RootElement> RootElements;
111+
Parser Parser(ViewsString, &RootElements);
112+
ASSERT_FALSE(Parser.Parse());
113+
ASSERT_EQ(RootElements.size(), (unsigned long)1);
114+
115+
RootParameter Parameter = RootElements[0].Parameter;
116+
ASSERT_EQ(RootType::CBV, Parameter.Type);
117+
ASSERT_EQ(RegisterType::BReg, Parameter.Register.ViewType);
118+
ASSERT_EQ((uint32_t)982374, Parameter.Register.Number);
119+
ASSERT_EQ(RootDescriptorFlags::DataStatic, Parameter.Flags);
120+
ASSERT_EQ((uint32_t)1, Parameter.Space);
121+
ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
122+
}
123+
124+
TEST(ParseHLSLRootSignature, SampleRootSRV) {
125+
llvm::StringRef ViewsString = "SRV(t3, visibility = SHADER_VISIBILITY_MESH, "
126+
"flags = Data_Static_While_Set_At_Execute)";
127+
llvm::SmallVector<RootElement> RootElements;
128+
Parser Parser(ViewsString, &RootElements);
129+
ASSERT_FALSE(Parser.Parse());
130+
ASSERT_EQ(RootElements.size(), (unsigned long)1);
131+
132+
RootParameter Parameter = RootElements[0].Parameter;
133+
ASSERT_EQ(RootType::SRV, Parameter.Type);
134+
ASSERT_EQ(RegisterType::TReg, Parameter.Register.ViewType);
135+
ASSERT_EQ((uint32_t)3, Parameter.Register.Number);
136+
ASSERT_EQ(RootDescriptorFlags::DataStaticWhileSetAtExecute, Parameter.Flags);
137+
ASSERT_EQ((uint32_t)0, Parameter.Space);
138+
ASSERT_EQ(ShaderVisibility::Mesh, Parameter.Visibility);
139+
}
140+
141+
TEST(ParseHLSLRootSignature, SampleRootUAV) {
142+
llvm::StringRef ViewsString = "UAV(u0, flags = DATA_VOLATILE)";
143+
llvm::SmallVector<RootElement> RootElements;
144+
Parser Parser(ViewsString, &RootElements);
145+
ASSERT_FALSE(Parser.Parse());
146+
ASSERT_EQ(RootElements.size(), (unsigned long)1);
147+
148+
RootParameter Parameter = RootElements[0].Parameter;
149+
ASSERT_EQ(RootType::UAV, Parameter.Type);
150+
ASSERT_EQ(RegisterType::UReg, Parameter.Register.ViewType);
151+
ASSERT_EQ((uint32_t)0, Parameter.Register.Number);
152+
ASSERT_EQ(RootDescriptorFlags::DataVolatile, Parameter.Flags);
153+
ASSERT_EQ((uint32_t)0, Parameter.Space);
154+
ASSERT_EQ(ShaderVisibility::All, Parameter.Visibility);
155+
}
58156
} // anonymous namespace

0 commit comments

Comments
 (0)