Skip to content

Commit 6a9d43a

Browse files
authored
[Sema][Parse][HLSL] Implement front-end rootsignature validations (#156754)
This pr implements the following validations: 1. Check that descriptor tables don't mix Sample and non-Sampler resources 2. Ensure that descriptor ranges don't append onto an unbounded range 3. Ensure that descriptor ranges don't overflow 4. Adds a missing validation to ensure that only a single `RootFlags` parameter is provided Resolves: #153868.
1 parent 11c2ffd commit 6a9d43a

File tree

9 files changed

+187
-15
lines changed

9 files changed

+187
-15
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13202,6 +13202,9 @@ def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Litera
1320213202

1320313203
def err_hlsl_invalid_rootsig_value : Error<"value must be in the range [%0, %1]">;
1320413204
def err_hlsl_invalid_rootsig_flag : Error< "invalid flags for version 1.%0">;
13205+
def err_hlsl_invalid_mixed_resources: Error< "sampler and non-sampler resource mixed in descriptor table">;
13206+
def err_hlsl_appending_onto_unbound: Error<"offset appends to unbounded descriptor range">;
13207+
def err_hlsl_offset_overflow: Error<"descriptor range offset overflows [%0, %1]">;
1320513208

1320613209
def subst_hlsl_format_ranges: TextSubstitution<
1320713210
"%select{t|u|b|s}0[%1;%select{%3]|unbounded)}2">;

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,18 @@ bool RootSignatureParser::parse() {
3838
// Iterate as many RootSignatureElements as possible, until we hit the
3939
// end of the stream
4040
bool HadError = false;
41+
bool HasRootFlags = false;
4142
while (!peekExpectedToken(TokenKind::end_of_stream)) {
4243
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
44+
if (HasRootFlags) {
45+
reportDiag(diag::err_hlsl_rootsig_repeat_param)
46+
<< TokenKind::kw_RootFlags;
47+
HadError = true;
48+
skipUntilExpectedToken(RootElementKeywords);
49+
continue;
50+
}
51+
HasRootFlags = true;
52+
4353
SourceLocation ElementLoc = getTokenLocation(CurToken);
4454
auto Flags = parseRootFlags();
4555
if (!Flags.has_value()) {

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,12 +1359,48 @@ bool SemaHLSL::handleRootSignatureElements(
13591359
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
13601360
assert(UnboundClauses.size() == Table->NumClauses &&
13611361
"Number of unbound elements must match the number of clauses");
1362+
bool HasAnySampler = false;
1363+
bool HasAnyNonSampler = false;
1364+
uint32_t Offset = 0;
13621365
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1363-
uint32_t LowerBound(Clause->Reg.Number);
1366+
SourceLocation Loc = ClauseElem->getLocation();
1367+
if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
1368+
HasAnySampler = true;
1369+
else
1370+
HasAnyNonSampler = true;
1371+
1372+
if (HasAnySampler && HasAnyNonSampler)
1373+
Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
1374+
13641375
// Relevant error will have already been reported above and needs to be
1365-
// fixed before we can conduct range analysis, so shortcut error return
1376+
// fixed before we can conduct further analysis, so shortcut error
1377+
// return
13661378
if (Clause->NumDescriptors == 0)
13671379
return true;
1380+
1381+
if (Clause->Offset !=
1382+
llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
1383+
// Manually specified the offset
1384+
Offset = Clause->Offset;
1385+
}
1386+
1387+
uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
1388+
Offset, Clause->NumDescriptors);
1389+
1390+
if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
1391+
// Trying to append onto unbound offset
1392+
Diag(Loc, diag::err_hlsl_appending_onto_unbound);
1393+
} else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound)) {
1394+
// Upper bound overflows maximum offset
1395+
Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
1396+
}
1397+
1398+
Offset = RangeBound == llvm::hlsl::rootsig::NumDescriptorsUnbounded
1399+
? uint32_t(RangeBound)
1400+
: uint32_t(RangeBound + 1);
1401+
1402+
// Compute the register bounds and track resource binding
1403+
uint32_t LowerBound(Clause->Reg.Number);
13681404
uint32_t UpperBound = Clause->NumDescriptors == ~0u
13691405
? ~0u
13701406
: LowerBound + Clause->NumDescriptors - 1;

clang/test/SemaHLSL/RootSignature-err.hlsl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void basic_validation_3() {}
179179

180180
// expected-error@+2 {{value must be in the range [1, 4294967294]}}
181181
// expected-error@+1 {{value must be in the range [1, 4294967294]}}
182-
[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0), Sampler(s0, numDescriptors = 0))")]
182+
[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0)), DescriptorTable(Sampler(s0, numDescriptors = 0))")]
183183
void basic_validation_4() {}
184184

185185
// expected-error@+2 {{value must be in the range [0, 16]}}
@@ -189,4 +189,8 @@ void basic_validation_5() {}
189189

190190
// expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
191191
[RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
192-
void basic_validation_6() {}
192+
void basic_validation_6() {}
193+
194+
// expected-error@+1 {{sampler and non-sampler resource mixed in descriptor table}}
195+
[RootSignature("DescriptorTable(Sampler(s0), CBV(b0))")]
196+
void mixed_resource_table() {}

clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,28 @@ void bad_root_signature_14() {}
117117
// expected-note@+1 {{overlapping resource range here}}
118118
[RootSignature(DuplicatesRootSignature)]
119119
void valid_root_signature_15() {}
120+
121+
#define AppendingToUnbound \
122+
"DescriptorTable(CBV(b1, numDescriptors = unbounded), CBV(b0))"
123+
124+
// expected-error@+1 {{offset appends to unbounded descriptor range}}
125+
[RootSignature(AppendingToUnbound)]
126+
void append_to_unbound_signature() {}
127+
128+
#define DirectOffsetOverflow \
129+
"DescriptorTable(CBV(b0, offset = 4294967294 , numDescriptors = 6))"
130+
131+
// expected-error@+1 {{descriptor range offset overflows [4294967294, 4294967299]}}
132+
[RootSignature(DirectOffsetOverflow)]
133+
void direct_offset_overflow_signature() {}
134+
135+
#define AppendOffsetOverflow \
136+
"DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 7))"
137+
138+
// expected-error@+1 {{descriptor range offset overflows [4294967293, 4294967299]}}
139+
[RootSignature(AppendOffsetOverflow)]
140+
void append_offset_overflow_signature() {}
141+
142+
// expected-error@+1 {{descriptor range offset overflows [4294967292, 4294967296]}}
143+
[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292, numDescriptors = 5))")]
144+
void offset_() {}

clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,6 @@ void valid_root_signature_5() {}
2222

2323
[RootSignature("DescriptorTable(SRV(t5), UAV(u5, numDescriptors=2))")]
2424
void valid_root_signature_6() {}
25+
26+
[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 3))")]
27+
void valid_root_signature_7() {}

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,6 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
501501
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
502502
using llvm::dxbc::RootFlags;
503503
const llvm::StringLiteral Source = R"cc(
504-
RootFlags(),
505-
RootFlags(0),
506504
RootFlags(
507505
deny_domain_shader_root_access |
508506
deny_pixel_shader_root_access |
@@ -533,18 +531,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
533531
ASSERT_FALSE(Parser.parse());
534532

535533
auto Elements = Parser.getElements();
536-
ASSERT_EQ(Elements.size(), 3u);
534+
ASSERT_EQ(Elements.size(), 1u);
537535

538536
RootElement Elem = Elements[0].getElement();
539537
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
540-
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
541-
542-
Elem = Elements[1].getElement();
543-
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
544-
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
545-
546-
Elem = Elements[2].getElement();
547-
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
548538
auto ValidRootFlags = RootFlags::AllowInputAssemblerInputLayout |
549539
RootFlags::DenyVertexShaderRootAccess |
550540
RootFlags::DenyHullShaderRootAccess |
@@ -562,6 +552,64 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
562552
ASSERT_TRUE(Consumer->isSatisfied());
563553
}
564554

555+
TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyRootFlagsTest) {
556+
using llvm::dxbc::RootFlags;
557+
const llvm::StringLiteral Source = R"cc(
558+
RootFlags(),
559+
)cc";
560+
561+
auto Ctx = createMinimalASTContext();
562+
StringLiteral *Signature = wrapSource(Ctx, Source);
563+
564+
TrivialModuleLoader ModLoader;
565+
auto PP = createPP(Source, ModLoader);
566+
567+
hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
568+
569+
// Test no diagnostics produced
570+
Consumer->setNoDiag();
571+
572+
ASSERT_FALSE(Parser.parse());
573+
574+
auto Elements = Parser.getElements();
575+
ASSERT_EQ(Elements.size(), 1u);
576+
577+
RootElement Elem = Elements[0].getElement();
578+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
579+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
580+
581+
ASSERT_TRUE(Consumer->isSatisfied());
582+
}
583+
584+
TEST_F(ParseHLSLRootSignatureTest, ValidParseZeroRootFlagsTest) {
585+
using llvm::dxbc::RootFlags;
586+
const llvm::StringLiteral Source = R"cc(
587+
RootFlags(0),
588+
)cc";
589+
590+
auto Ctx = createMinimalASTContext();
591+
StringLiteral *Signature = wrapSource(Ctx, Source);
592+
593+
TrivialModuleLoader ModLoader;
594+
auto PP = createPP(Source, ModLoader);
595+
596+
hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
597+
598+
// Test no diagnostics produced
599+
Consumer->setNoDiag();
600+
601+
ASSERT_FALSE(Parser.parse());
602+
603+
auto Elements = Parser.getElements();
604+
ASSERT_EQ(Elements.size(), 1u);
605+
606+
RootElement Elem = Elements[0].getElement();
607+
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
608+
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
609+
610+
ASSERT_TRUE(Consumer->isSatisfied());
611+
}
612+
565613
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
566614
using llvm::dxbc::RootDescriptorFlags;
567615
const llvm::StringLiteral Source = R"cc(
@@ -1658,4 +1706,27 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorRangeFlagsValueTest) {
16581706
ASSERT_TRUE(Consumer->isSatisfied());
16591707
}
16601708

1709+
TEST_F(ParseHLSLRootSignatureTest, InvalidMultipleRootFlagsTest) {
1710+
// This test will check that an error is produced when there are multiple
1711+
// root flags provided
1712+
const llvm::StringLiteral Source = R"cc(
1713+
RootFlags(DENY_VERTEX_SHADER_ROOT_ACCESS),
1714+
RootFlags(DENY_PIXEL_SHADER_ROOT_ACCESS)
1715+
)cc";
1716+
1717+
auto Ctx = createMinimalASTContext();
1718+
StringLiteral *Signature = wrapSource(Ctx, Source);
1719+
1720+
TrivialModuleLoader ModLoader;
1721+
auto PP = createPP(Source, ModLoader);
1722+
1723+
hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
1724+
1725+
// Test correct diagnostic produced
1726+
Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
1727+
ASSERT_TRUE(Parser.parse());
1728+
1729+
ASSERT_TRUE(Consumer->isSatisfied());
1730+
}
1731+
16611732
} // anonymous namespace

llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc);
4141
LLVM_ABI bool verifyBorderColor(uint32_t BorderColor);
4242
LLVM_ABI bool verifyLOD(float LOD);
4343

44+
LLVM_ABI bool verifyBoundOffset(uint32_t Offset);
45+
LLVM_ABI bool verifyNoOverflowedOffset(uint64_t Offset);
46+
LLVM_ABI uint64_t computeRangeBound(uint32_t Offset, uint32_t Size);
47+
4448
} // namespace rootsig
4549
} // namespace hlsl
4650
} // namespace llvm

llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,22 @@ bool verifyBorderColor(uint32_t BorderColor) {
180180

181181
bool verifyLOD(float LOD) { return !std::isnan(LOD); }
182182

183+
bool verifyBoundOffset(uint32_t Offset) {
184+
return Offset != NumDescriptorsUnbounded;
185+
}
186+
187+
bool verifyNoOverflowedOffset(uint64_t Offset) {
188+
return Offset <= std::numeric_limits<uint32_t>::max();
189+
}
190+
191+
uint64_t computeRangeBound(uint32_t Offset, uint32_t Size) {
192+
assert(0 < Size && "Must be a non-empty range");
193+
if (Size == NumDescriptorsUnbounded)
194+
return NumDescriptorsUnbounded;
195+
196+
return uint64_t(Offset) + uint64_t(Size) - 1;
197+
}
198+
183199
} // namespace rootsig
184200
} // namespace hlsl
185201
} // namespace llvm

0 commit comments

Comments
 (0)