-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[Sema][Parse][HLSL] Implement front-end rootsignature validations #156754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-hlsl Author: Finn Plummer (inbelic) ChangesThis pr implements the following validations:
Resolves: #153868. Full diff: https://github.com/llvm/llvm-project/pull/156754.diff 9 Files Affected:
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index c934fed2c7462..8bb47e3a4d46d 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -13184,6 +13184,9 @@ def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Litera
def err_hlsl_invalid_rootsig_value : Error<"value must be in the range [%0, %1]">;
def err_hlsl_invalid_rootsig_flag : Error< "invalid flags for version 1.%0">;
+def err_hlsl_invalid_mixed_resources: Error< "sampler and non-sampler resource mixed in descriptor table">;
+def err_hlsl_appending_onto_unbound: Error<"offset appends to unbounded descriptor range">;
+def err_hlsl_offset_overflow: Error<"descriptor range offset overflows [%0, %1]">;
def subst_hlsl_format_ranges: TextSubstitution<
"%select{t|u|b|s}0[%1;%select{%3]|unbounded)}2">;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 1af72f8b1c934..7dd0c3e90886b 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -37,8 +37,18 @@ bool RootSignatureParser::parse() {
// Iterate as many RootSignatureElements as possible, until we hit the
// end of the stream
bool HadError = false;
+ bool HasRootFlags = false;
while (!peekExpectedToken(TokenKind::end_of_stream)) {
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+ if (HasRootFlags) {
+ reportDiag(diag::err_hlsl_rootsig_repeat_param)
+ << TokenKind::kw_RootFlags;
+ HadError = true;
+ skipUntilExpectedToken(RootElementKeywords);
+ continue;
+ }
+ HasRootFlags = true;
+
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value()) {
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1e5ec952c1ecf..4cf08eac6d171 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1331,12 +1331,47 @@ bool SemaHLSL::handleRootSignatureElements(
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
assert(UnboundClauses.size() == Table->NumClauses &&
"Number of unbound elements must match the number of clauses");
+ bool HasSampler = false;
+ bool HasNonSampler = false;
+ uint32_t Offset = 0;
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
- uint32_t LowerBound(Clause->Reg.Number);
+ SourceLocation Loc = RootSigElem.getLocation();
+ if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
+ HasSampler = true;
+ else
+ HasNonSampler = true;
+
+ if (HasSampler && HasNonSampler)
+ Diag(Loc, diag::err_hlsl_invalid_mixed_resources);
+
// Relevant error will have already been reported above and needs to be
- // fixed before we can conduct range analysis, so shortcut error return
+ // fixed before we can conduct further analysis, so shortcut error
+ // return
if (Clause->NumDescriptors == 0)
return true;
+
+ if (Clause->Offset !=
+ llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
+ // Manually specified the offset
+ Offset = Clause->Offset;
+ }
+
+ uint64_t NextOffset =
+ llvm::hlsl::rootsig::nextOffset(Offset, Clause->NumDescriptors);
+
+ if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
+ // Trying to append onto unbound offset
+ Diag(Loc, diag::err_hlsl_appending_onto_unbound);
+ } else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(NextOffset -
+ 1)) {
+ // Upper bound overflows maximum offset
+ Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << NextOffset - 1;
+ }
+
+ Offset = uint32_t(NextOffset);
+
+ // Compute the register bounds and track resource binding
+ uint32_t LowerBound(Clause->Reg.Number);
uint32_t UpperBound = Clause->NumDescriptors == ~0u
? ~0u
: LowerBound + Clause->NumDescriptors - 1;
diff --git a/clang/test/SemaHLSL/RootSignature-err.hlsl b/clang/test/SemaHLSL/RootSignature-err.hlsl
index ccfa093baeb87..89c684cd8d11f 100644
--- a/clang/test/SemaHLSL/RootSignature-err.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-err.hlsl
@@ -179,7 +179,7 @@ void basic_validation_3() {}
// expected-error@+2 {{value must be in the range [1, 4294967294]}}
// expected-error@+1 {{value must be in the range [1, 4294967294]}}
-[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0), Sampler(s0, numDescriptors = 0))")]
+[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0)), DescriptorTable(Sampler(s0, numDescriptors = 0))")]
void basic_validation_4() {}
// expected-error@+2 {{value must be in the range [0, 16]}}
@@ -189,4 +189,8 @@ void basic_validation_5() {}
// expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
[RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
-void basic_validation_6() {}
\ No newline at end of file
+void basic_validation_6() {}
+
+// expected-error@+1 {{sampler and non-sampler resource mixed in descriptor table}}
+[RootSignature("DescriptorTable(Sampler(s0), CBV(b0))")]
+void mixed_resource_table() {}
diff --git a/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl b/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
index fd098b01cc723..2d025d0e6e5ce 100644
--- a/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
@@ -117,3 +117,28 @@ void bad_root_signature_14() {}
// expected-note@+1 {{overlapping resource range here}}
[RootSignature(DuplicatesRootSignature)]
void valid_root_signature_15() {}
+
+#define AppendingToUnbound \
+ "DescriptorTable(CBV(b1, numDescriptors = unbounded), CBV(b0))"
+
+// expected-error@+1 {{offset appends to unbounded descriptor range}}
+[RootSignature(AppendingToUnbound)]
+void append_to_unbound_signature() {}
+
+#define DirectOffsetOverflow \
+ "DescriptorTable(CBV(b0, offset = 4294967294 , numDescriptors = 6))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967294, 4294967299]}}
+[RootSignature(DirectOffsetOverflow)]
+void direct_offset_overflow_signature() {}
+
+#define AppendOffsetOverflow \
+ "DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 7))"
+
+// expected-error@+1 {{descriptor range offset overflows [4294967293, 4294967299]}}
+[RootSignature(AppendOffsetOverflow)]
+void append_offset_overflow_signature() {}
+
+// expected-error@+1 {{descriptor range offset overflows [4294967292, 4294967296]}}
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292, numDescriptors = 5))")]
+void offset_() {}
diff --git a/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl b/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
index 09a1110b0fbc1..10e7215eccf6e 100644
--- a/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
+++ b/clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
@@ -22,3 +22,6 @@ void valid_root_signature_5() {}
[RootSignature("DescriptorTable(SRV(t5), UAV(u5, numDescriptors=2))")]
void valid_root_signature_6() {}
+
+[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 3))")]
+void valid_root_signature_7() {}
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 44c0978a243bc..9b9f5dd8a63bb 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -501,8 +501,6 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
using llvm::dxbc::RootFlags;
const llvm::StringLiteral Source = R"cc(
- RootFlags(),
- RootFlags(0),
RootFlags(
deny_domain_shader_root_access |
deny_pixel_shader_root_access |
@@ -533,18 +531,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
ASSERT_FALSE(Parser.parse());
auto Elements = Parser.getElements();
- ASSERT_EQ(Elements.size(), 3u);
+ ASSERT_EQ(Elements.size(), 1u);
RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
- ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
- Elem = Elements[1].getElement();
- ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
- ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
-
- Elem = Elements[2].getElement();
- ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
auto ValidRootFlags = RootFlags::AllowInputAssemblerInputLayout |
RootFlags::DenyVertexShaderRootAccess |
RootFlags::DenyHullShaderRootAccess |
@@ -562,6 +552,64 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyRootFlagsTest) {
+ using llvm::dxbc::RootFlags;
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(),
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ auto Elements = Parser.getElements();
+ ASSERT_EQ(Elements.size(), 1u);
+
+ RootElement Elem = Elements[0].getElement();
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
+TEST_F(ParseHLSLRootSignatureTest, ValidParseZeroRootFlagsTest) {
+ using llvm::dxbc::RootFlags;
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(0),
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test no diagnostics produced
+ Consumer->setNoDiag();
+
+ ASSERT_FALSE(Parser.parse());
+
+ auto Elements = Parser.getElements();
+ ASSERT_EQ(Elements.size(), 1u);
+
+ RootElement Elem = Elements[0].getElement();
+ ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
+ ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
using llvm::dxbc::RootDescriptorFlags;
const llvm::StringLiteral Source = R"cc(
@@ -1658,4 +1706,27 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorRangeFlagsValueTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}
+TEST_F(ParseHLSLRootSignatureTest, InvalidMultipleRootFlagsTest) {
+ // This test will check that an error is produced when there are multiple
+ // root flags provided
+ const llvm::StringLiteral Source = R"cc(
+ RootFlags(DENY_VERTEX_SHADER_ROOT_ACCESS),
+ RootFlags(DENY_PIXEL_SHADER_ROOT_ACCESS)
+ )cc";
+
+ auto Ctx = createMinimalASTContext();
+ StringLiteral *Signature = wrapSource(Ctx, Source);
+
+ TrivialModuleLoader ModLoader;
+ auto PP = createPP(Source, ModLoader);
+
+ hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);
+
+ // Test correct diagnostic produced
+ Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
+ ASSERT_TRUE(Parser.parse());
+
+ ASSERT_TRUE(Consumer->isSatisfied());
+}
+
} // anonymous namespace
diff --git a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
index fde32a1fff591..5ffd31ecb2650 100644
--- a/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
+++ b/llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
@@ -41,6 +41,10 @@ LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc);
LLVM_ABI bool verifyBorderColor(uint32_t BorderColor);
LLVM_ABI bool verifyLOD(float LOD);
+LLVM_ABI bool verifyBoundOffset(uint32_t Offset);
+LLVM_ABI bool verifyNoOverflowedOffset(uint64_t Offset);
+LLVM_ABI uint64_t nextOffset(uint32_t Offset, uint32_t Size);
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index 72308a3de5fd4..f19354ceb6072 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -180,6 +180,21 @@ bool verifyBorderColor(uint32_t BorderColor) {
bool verifyLOD(float LOD) { return !std::isnan(LOD); }
+bool verifyBoundOffset(uint32_t Offset) {
+ return Offset != NumDescriptorsUnbounded;
+}
+
+bool verifyNoOverflowedOffset(uint64_t Offset) {
+ return Offset <= std::numeric_limits<uint32_t>::max();
+}
+
+uint64_t nextOffset(uint32_t Offset, uint32_t Size) {
+ if (Size == NumDescriptorsUnbounded)
+ return NumDescriptorsUnbounded;
+
+ return uint64_t(Offset) + uint64_t(Size);
+}
+
} // namespace rootsig
} // namespace hlsl
} // namespace llvm
|
if (HasRootFlags) { | ||
reportDiag(diag::err_hlsl_rootsig_repeat_param) | ||
<< TokenKind::kw_RootFlags; | ||
HadError = true; | ||
skipUntilExpectedToken(RootElementKeywords); | ||
continue; | ||
} | ||
HasRootFlags = true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This doesn't seem documented in the description of this PR, is it checking if we are defining multiple RootFlags
in a root signature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should correspond to point 4 of the pr description. However, it isn't described in the original issue as I just found this discrepancy when trying to implement the other checks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, small nit
This pr implements the following validations:
RootFlags
parameter is providedResolves: #153868.