Skip to content

Commit 3daccf8

Browse files
committed
[HLSL][RootSignature] Add parsing of flags to RootParam
- defines RootDescriptorFlags in-memory representation - defines parseRootDescriptorFlags to be DXC compatible. This is why we support multiple `|` flags even validation will assert that only one flag is set... - add unit tests to demonstrate functionality
1 parent 2b64b15 commit 3daccf8

File tree

4 files changed

+105
-3
lines changed

4 files changed

+105
-3
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class RootSignatureParser {
9393
std::optional<llvm::hlsl::rootsig::Register> Reg;
9494
std::optional<uint32_t> Space;
9595
std::optional<llvm::hlsl::rootsig::ShaderVisibility> Visibility;
96+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> Flags;
9697
};
9798
std::optional<ParsedRootDescriptorParams>
9899
parseRootDescriptorParams(RootSignatureToken::Kind RegType);
@@ -113,6 +114,8 @@ class RootSignatureParser {
113114

114115
/// Parsing methods of various enums
115116
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
117+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
118+
parseRootDescriptorFlags();
116119
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
117120
parseDescriptorRangeFlags();
118121

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
193193
ExpectedReg = TokenKind::uReg;
194194
break;
195195
}
196+
Param.setDefaultFlags();
196197

197198
auto Params = parseRootDescriptorParams(ExpectedReg);
198199
if (!Params.has_value())
@@ -214,6 +215,9 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
214215
if (Params->Visibility.has_value())
215216
Descriptor.Visibility = Params->Visibility.value();
216217

218+
if (Params->Flags.has_value())
219+
Param.Flags = Params->Flags.value();
220+
217221
if (consumeExpectedToken(TokenKind::pu_r_paren,
218222
diag::err_hlsl_unexpected_end_of_params,
219223
/*param of=*/TokenKind::kw_RootConstants))
@@ -475,6 +479,23 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
475479
return std::nullopt;
476480
Params.Visibility = Visibility;
477481
}
482+
483+
// `flags` `=` ROOT_DESCRIPTOR_FLAGS
484+
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
485+
if (Params.Flags.has_value()) {
486+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
487+
<< CurToken.TokKind;
488+
return std::nullopt;
489+
}
490+
491+
if (consumeExpectedToken(TokenKind::pu_equal))
492+
return std::nullopt;
493+
494+
auto Flags = parseRootDescriptorFlags();
495+
if (!Flags.has_value())
496+
return std::nullopt;
497+
Params.Flags = Flags;
498+
}
478499
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
479500

480501
return Params;
@@ -654,6 +675,45 @@ RootSignatureParser::parseShaderVisibility() {
654675
return std::nullopt;
655676
}
656677

678+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
679+
RootSignatureParser::parseRootDescriptorFlags() {
680+
assert(CurToken.TokKind == TokenKind::pu_equal &&
681+
"Expects to only be invoked starting at given keyword");
682+
683+
// Handle the edge-case of '0' to specify no flags set
684+
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
685+
if (!verifyZeroFlag()) {
686+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
687+
return std::nullopt;
688+
}
689+
return RootDescriptorFlags::None;
690+
}
691+
692+
TokenKind Expected[] = {
693+
#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
694+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
695+
};
696+
697+
std::optional<RootDescriptorFlags> Flags;
698+
699+
do {
700+
if (tryConsumeExpectedToken(Expected)) {
701+
switch (CurToken.TokKind) {
702+
#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) \
703+
case TokenKind::en_##NAME: \
704+
Flags = \
705+
maybeOrFlag<RootDescriptorFlags>(Flags, RootDescriptorFlags::NAME); \
706+
break;
707+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
708+
default:
709+
llvm_unreachable("Switch for consumed enum token was not provided");
710+
}
711+
}
712+
} while (tryConsumeExpectedToken(TokenKind::pu_or));
713+
714+
return Flags;
715+
}
716+
657717
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
658718
RootSignatureParser::parseDescriptorRangeFlags() {
659719
assert(CurToken.TokKind == TokenKind::pu_equal &&

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,11 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
347347
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
348348
const llvm::StringLiteral Source = R"cc(
349349
CBV(b0),
350-
SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY),
351-
UAV(visibility = SHADER_VISIBILITY_HULL, u34893247)
350+
SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY,
351+
flags = DATA_VOLATILE | DATA_STATIC | DATA_STATIC_WHILE_SET_AT_EXECUTE
352+
),
353+
UAV(visibility = SHADER_VISIBILITY_HULL, u34893247),
354+
CBV(b0, flags = 0),
352355
)cc";
353356

354357
TrivialModuleLoader ModLoader;
@@ -364,7 +367,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
364367

365368
ASSERT_FALSE(Parser.parse());
366369

367-
ASSERT_EQ(Elements.size(), 3u);
370+
ASSERT_EQ(Elements.size(), 4u);
368371

369372
RootElement Elem = Elements[0];
370373
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
@@ -373,6 +376,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
373376
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 0u);
374377
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
375378
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility, ShaderVisibility::All);
379+
ASSERT_EQ(std::get<RootParam>(Elem).Flags,
380+
RootDescriptorFlags::DataStaticWhileSetAtExecute);
376381

377382
Elem = Elements[1];
378383
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
@@ -382,6 +387,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
382387
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 4u);
383388
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility,
384389
ShaderVisibility::Geometry);
390+
ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::ValidFlags);
385391

386392
Elem = Elements[2];
387393
ASSERT_TRUE(std::holds_alternative<RootDescriptor>(Elem));
@@ -390,6 +396,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
390396
ASSERT_EQ(std::get<RootDescriptor>(Elem).Reg.Number, 34893247u);
391397
ASSERT_EQ(std::get<RootDescriptor>(Elem).Space, 0u);
392398
ASSERT_EQ(std::get<RootDescriptor>(Elem).Visibility, ShaderVisibility::Hull);
399+
ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::DataVolatile);
400+
401+
Elem = Elements[3];
402+
ASSERT_EQ(std::get<RootParam>(Elem).Reg.ViewType, RegisterType::BReg);
403+
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u);
404+
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
405+
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All);
406+
ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::None);
393407

394408
ASSERT_TRUE(Consumer->isSatisfied());
395409
}

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ enum class RootFlags : uint32_t {
4646
ValidFlags = 0x00000fff
4747
};
4848

49+
enum class RootDescriptorFlags : unsigned {
50+
None = 0,
51+
DataVolatile = 0x2,
52+
DataStaticWhileSetAtExecute = 0x4,
53+
DataStatic = 0x8,
54+
ValidFlags = 0xe,
55+
};
56+
4957
enum class DescriptorRangeFlags : unsigned {
5058
None = 0,
5159
DescriptorsVolatile = 0x1,
@@ -92,6 +100,23 @@ struct RootDescriptor {
92100
Register Reg;
93101
uint32_t Space = 0;
94102
ShaderVisibility Visibility = ShaderVisibility::All;
103+
RootDescriptorFlags Flags;
104+
105+
void setDefaultFlags() {
106+
assert(Type != ParamType::Sampler &&
107+
"Sampler is not a valid type of ParamType");
108+
switch (Type) {
109+
case ParamType::CBuffer:
110+
case ParamType::SRV:
111+
Flags = RootDescriptorFlags::DataStaticWhileSetAtExecute;
112+
break;
113+
case ParamType::UAV:
114+
Flags = RootDescriptorFlags::DataVolatile;
115+
break;
116+
case ParamType::Sampler:
117+
break;
118+
}
119+
}
95120
};
96121

97122
// Models the end of a descriptor table and stores its visibility

0 commit comments

Comments
 (0)