Skip to content

Commit ab39042

Browse files
committed
[HLSL][RootSignature] Add parsing of flags to RootParam
- defines RootDescriptorFlags in-memory representation - defines parseRootDescriptorFlags to be DXC compatible. This is why we support multiple `|` flags even validation will assert that only one flag is set... - add unit tests to demonstrate functionality
1 parent 0ae2c55 commit ab39042

File tree

4 files changed

+105
-3
lines changed

4 files changed

+105
-3
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class RootSignatureParser {
9393
std::optional<llvm::hlsl::rootsig::Register> Reg;
9494
std::optional<uint32_t> Space;
9595
std::optional<llvm::hlsl::rootsig::ShaderVisibility> Visibility;
96+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags> Flags;
9697
};
9798
std::optional<ParsedRootParamParams>
9899
parseRootParamParams(RootSignatureToken::Kind RegType);
@@ -113,6 +114,8 @@ class RootSignatureParser {
113114

114115
/// Parsing methods of various enums
115116
std::optional<llvm::hlsl::rootsig::ShaderVisibility> parseShaderVisibility();
117+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
118+
parseRootDescriptorFlags();
116119
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
117120
parseDescriptorRangeFlags();
118121

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ std::optional<RootParam> RootSignatureParser::parseRootParam() {
193193
ExpectedReg = TokenKind::uReg;
194194
break;
195195
}
196+
Param.setDefaultFlags();
196197

197198
auto Params = parseRootParamParams(ExpectedReg);
198199
if (!Params.has_value())
@@ -214,6 +215,9 @@ std::optional<RootParam> RootSignatureParser::parseRootParam() {
214215
if (Params->Visibility.has_value())
215216
Param.Visibility = Params->Visibility.value();
216217

218+
if (Params->Flags.has_value())
219+
Param.Flags = Params->Flags.value();
220+
217221
if (consumeExpectedToken(TokenKind::pu_r_paren,
218222
diag::err_hlsl_unexpected_end_of_params,
219223
/*param of=*/TokenKind::kw_RootConstants))
@@ -475,6 +479,23 @@ RootSignatureParser::parseRootParamParams(TokenKind RegType) {
475479
return std::nullopt;
476480
Params.Visibility = Visibility;
477481
}
482+
483+
// `flags` `=` ROOT_DESCRIPTOR_FLAGS
484+
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
485+
if (Params.Flags.has_value()) {
486+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
487+
<< CurToken.TokKind;
488+
return std::nullopt;
489+
}
490+
491+
if (consumeExpectedToken(TokenKind::pu_equal))
492+
return std::nullopt;
493+
494+
auto Flags = parseRootDescriptorFlags();
495+
if (!Flags.has_value())
496+
return std::nullopt;
497+
Params.Flags = Flags;
498+
}
478499
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
479500

480501
return Params;
@@ -654,6 +675,45 @@ RootSignatureParser::parseShaderVisibility() {
654675
return std::nullopt;
655676
}
656677

678+
std::optional<llvm::hlsl::rootsig::RootDescriptorFlags>
679+
RootSignatureParser::parseRootDescriptorFlags() {
680+
assert(CurToken.TokKind == TokenKind::pu_equal &&
681+
"Expects to only be invoked starting at given keyword");
682+
683+
// Handle the edge-case of '0' to specify no flags set
684+
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
685+
if (!verifyZeroFlag()) {
686+
getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_non_zero_flag);
687+
return std::nullopt;
688+
}
689+
return RootDescriptorFlags::None;
690+
}
691+
692+
TokenKind Expected[] = {
693+
#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) TokenKind::en_##NAME,
694+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
695+
};
696+
697+
std::optional<RootDescriptorFlags> Flags;
698+
699+
do {
700+
if (tryConsumeExpectedToken(Expected)) {
701+
switch (CurToken.TokKind) {
702+
#define ROOT_DESCRIPTOR_FLAG_ENUM(NAME, LIT) \
703+
case TokenKind::en_##NAME: \
704+
Flags = \
705+
maybeOrFlag<RootDescriptorFlags>(Flags, RootDescriptorFlags::NAME); \
706+
break;
707+
#include "clang/Lex/HLSLRootSignatureTokenKinds.def"
708+
default:
709+
llvm_unreachable("Switch for consumed enum token was not provided");
710+
}
711+
}
712+
} while (tryConsumeExpectedToken(TokenKind::pu_or));
713+
714+
return Flags;
715+
}
716+
657717
std::optional<llvm::hlsl::rootsig::DescriptorRangeFlags>
658718
RootSignatureParser::parseDescriptorRangeFlags() {
659719
assert(CurToken.TokKind == TokenKind::pu_equal &&

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,11 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
347347
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
348348
const llvm::StringLiteral Source = R"cc(
349349
CBV(b0),
350-
SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY),
351-
UAV(visibility = SHADER_VISIBILITY_HULL, u34893247)
350+
SRV(space = 4, t42, visibility = SHADER_VISIBILITY_GEOMETRY,
351+
flags = DATA_VOLATILE | DATA_STATIC | DATA_STATIC_WHILE_SET_AT_EXECUTE
352+
),
353+
UAV(visibility = SHADER_VISIBILITY_HULL, u34893247),
354+
CBV(b0, flags = 0),
352355
)cc";
353356

354357
TrivialModuleLoader ModLoader;
@@ -364,14 +367,16 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
364367

365368
ASSERT_FALSE(Parser.parse());
366369

367-
ASSERT_EQ(Elements.size(), 3u);
370+
ASSERT_EQ(Elements.size(), 4u);
368371

369372
RootElement Elem = Elements[0];
370373
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
371374
ASSERT_EQ(std::get<RootParam>(Elem).Reg.ViewType, RegisterType::BReg);
372375
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u);
373376
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
374377
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All);
378+
ASSERT_EQ(std::get<RootParam>(Elem).Flags,
379+
RootDescriptorFlags::DataStaticWhileSetAtExecute);
375380

376381
Elem = Elements[1];
377382
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
@@ -380,6 +385,7 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
380385
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 42u);
381386
ASSERT_EQ(std::get<RootParam>(Elem).Space, 4u);
382387
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::Geometry);
388+
ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::ValidFlags);
383389

384390
Elem = Elements[2];
385391
ASSERT_TRUE(std::holds_alternative<RootParam>(Elem));
@@ -388,6 +394,14 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootParamsTest) {
388394
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 34893247u);
389395
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
390396
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::Hull);
397+
ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::DataVolatile);
398+
399+
Elem = Elements[3];
400+
ASSERT_EQ(std::get<RootParam>(Elem).Reg.ViewType, RegisterType::BReg);
401+
ASSERT_EQ(std::get<RootParam>(Elem).Reg.Number, 0u);
402+
ASSERT_EQ(std::get<RootParam>(Elem).Space, 0u);
403+
ASSERT_EQ(std::get<RootParam>(Elem).Visibility, ShaderVisibility::All);
404+
ASSERT_EQ(std::get<RootParam>(Elem).Flags, RootDescriptorFlags::None);
391405

392406
ASSERT_TRUE(Consumer->isSatisfied());
393407
}

llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ enum class RootFlags : uint32_t {
4646
ValidFlags = 0x00000fff
4747
};
4848

49+
enum class RootDescriptorFlags : unsigned {
50+
None = 0,
51+
DataVolatile = 0x2,
52+
DataStaticWhileSetAtExecute = 0x4,
53+
DataStatic = 0x8,
54+
ValidFlags = 0xe,
55+
};
56+
4957
enum class DescriptorRangeFlags : unsigned {
5058
None = 0,
5159
DescriptorsVolatile = 0x1,
@@ -91,6 +99,23 @@ struct RootParam {
9199
Register Reg;
92100
uint32_t Space = 0;
93101
ShaderVisibility Visibility = ShaderVisibility::All;
102+
RootDescriptorFlags Flags;
103+
104+
void setDefaultFlags() {
105+
assert(Type != ParamType::Sampler &&
106+
"Sampler is not a valid type of ParamType");
107+
switch (Type) {
108+
case ParamType::CBuffer:
109+
case ParamType::SRV:
110+
Flags = RootDescriptorFlags::DataStaticWhileSetAtExecute;
111+
break;
112+
case ParamType::UAV:
113+
Flags = RootDescriptorFlags::DataVolatile;
114+
break;
115+
case ParamType::Sampler:
116+
break;
117+
}
118+
}
94119
};
95120

96121
// Models the end of a descriptor table and stores its visibility

0 commit comments

Comments
 (0)