Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -13195,6 +13195,9 @@ def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Litera

def err_hlsl_invalid_rootsig_value : Error<"value must be in the range [%0, %1]">;
def err_hlsl_invalid_rootsig_flag : Error< "invalid flags for version 1.%0">;
def err_hlsl_invalid_mixed_resources: Error< "sampler and non-sampler resource mixed in descriptor table">;
def err_hlsl_appending_onto_unbound: Error<"offset appends to unbounded descriptor range">;
def err_hlsl_offset_overflow: Error<"descriptor range offset overflows [%0, %1]">;

def subst_hlsl_format_ranges: TextSubstitution<
"%select{t|u|b|s}0[%1;%select{%3]|unbounded)}2">;
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,18 @@ bool RootSignatureParser::parse() {
// Iterate as many RootSignatureElements as possible, until we hit the
// end of the stream
bool HadError = false;
bool HasRootFlags = false;
while (!peekExpectedToken(TokenKind::end_of_stream)) {
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
if (HasRootFlags) {
reportDiag(diag::err_hlsl_rootsig_repeat_param)
<< TokenKind::kw_RootFlags;
HadError = true;
skipUntilExpectedToken(RootElementKeywords);
continue;
}
HasRootFlags = true;
Comment on lines +43 to +50
Copy link
Contributor

@joaosaffran joaosaffran Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This doesn't seem documented in the description of this PR, is it checking if we are defining multiple RootFlags in a root signature?

Copy link
Contributor Author

@inbelic inbelic Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should correspond to point 4 of the pr description. However, it isn't described in the original issue as I just found this discrepancy when trying to implement the other checks


SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value()) {
Expand Down
40 changes: 38 additions & 2 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,12 +1331,48 @@ bool SemaHLSL::handleRootSignatureElements(
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
assert(UnboundClauses.size() == Table->NumClauses &&
"Number of unbound elements must match the number of clauses");
bool HasSampler = false;
bool HasNonSampler = false;
uint32_t Offset = 0;
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
uint32_t LowerBound(Clause->Reg.Number);
SourceLocation Loc = ClauseElem->getLocation();
if (Clause->Type == llvm::dxil::ResourceClass::Sampler)
HasSampler = true;
else
HasNonSampler = true;

if (HasSampler && HasNonSampler)
Diag(Loc, diag::err_hlsl_invalid_mixed_resources);

// Relevant error will have already been reported above and needs to be
// fixed before we can conduct range analysis, so shortcut error return
// fixed before we can conduct further analysis, so shortcut error
// return
if (Clause->NumDescriptors == 0)
return true;

if (Clause->Offset !=
llvm::hlsl::rootsig::DescriptorTableOffsetAppend) {
// Manually specified the offset
Offset = Clause->Offset;
}

uint64_t RangeBound = llvm::hlsl::rootsig::computeRangeBound(
Offset, Clause->NumDescriptors);

if (!llvm::hlsl::rootsig::verifyBoundOffset(Offset)) {
// Trying to append onto unbound offset
Diag(Loc, diag::err_hlsl_appending_onto_unbound);
} else if (!llvm::hlsl::rootsig::verifyNoOverflowedOffset(RangeBound)) {
// Upper bound overflows maximum offset
Diag(Loc, diag::err_hlsl_offset_overflow) << Offset << RangeBound;
}

Offset = RangeBound == llvm::hlsl::rootsig::NumDescriptorsUnbounded
? uint32_t(RangeBound)
: uint32_t(RangeBound + 1);

// Compute the register bounds and track resource binding
uint32_t LowerBound(Clause->Reg.Number);
uint32_t UpperBound = Clause->NumDescriptors == ~0u
? ~0u
: LowerBound + Clause->NumDescriptors - 1;
Expand Down
8 changes: 6 additions & 2 deletions clang/test/SemaHLSL/RootSignature-err.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void basic_validation_3() {}

// expected-error@+2 {{value must be in the range [1, 4294967294]}}
// expected-error@+1 {{value must be in the range [1, 4294967294]}}
[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0), Sampler(s0, numDescriptors = 0))")]
[RootSignature("DescriptorTable(UAV(u0, numDescriptors = 0)), DescriptorTable(Sampler(s0, numDescriptors = 0))")]
void basic_validation_4() {}

// expected-error@+2 {{value must be in the range [0, 16]}}
Expand All @@ -189,4 +189,8 @@ void basic_validation_5() {}

// expected-error@+1 {{value must be in the range [-16.00, 15.99]}}
[RootSignature("StaticSampler(s0, mipLODBias = 15.990001)")]
void basic_validation_6() {}
void basic_validation_6() {}

// expected-error@+1 {{sampler and non-sampler resource mixed in descriptor table}}
[RootSignature("DescriptorTable(Sampler(s0), CBV(b0))")]
void mixed_resource_table() {}
25 changes: 25 additions & 0 deletions clang/test/SemaHLSL/RootSignature-resource-ranges-err.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,28 @@ void bad_root_signature_14() {}
// expected-note@+1 {{overlapping resource range here}}
[RootSignature(DuplicatesRootSignature)]
void valid_root_signature_15() {}

#define AppendingToUnbound \
"DescriptorTable(CBV(b1, numDescriptors = unbounded), CBV(b0))"

// expected-error@+1 {{offset appends to unbounded descriptor range}}
[RootSignature(AppendingToUnbound)]
void append_to_unbound_signature() {}

#define DirectOffsetOverflow \
"DescriptorTable(CBV(b0, offset = 4294967294 , numDescriptors = 6))"

// expected-error@+1 {{descriptor range offset overflows [4294967294, 4294967299]}}
[RootSignature(DirectOffsetOverflow)]
void direct_offset_overflow_signature() {}

#define AppendOffsetOverflow \
"DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 7))"

// expected-error@+1 {{descriptor range offset overflows [4294967293, 4294967299]}}
[RootSignature(AppendOffsetOverflow)]
void append_offset_overflow_signature() {}

// expected-error@+1 {{descriptor range offset overflows [4294967292, 4294967296]}}
[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292, numDescriptors = 5))")]
void offset_() {}
3 changes: 3 additions & 0 deletions clang/test/SemaHLSL/RootSignature-resource-ranges.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ void valid_root_signature_5() {}

[RootSignature("DescriptorTable(SRV(t5), UAV(u5, numDescriptors=2))")]
void valid_root_signature_6() {}

[RootSignature("DescriptorTable(CBV(b0, offset = 4294967292), CBV(b1, numDescriptors = 3))")]
void valid_root_signature_7() {}
93 changes: 82 additions & 11 deletions clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,6 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
using llvm::dxbc::RootFlags;
const llvm::StringLiteral Source = R"cc(
RootFlags(),
RootFlags(0),
RootFlags(
deny_domain_shader_root_access |
deny_pixel_shader_root_access |
Expand Down Expand Up @@ -533,18 +531,10 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
ASSERT_FALSE(Parser.parse());

auto Elements = Parser.getElements();
ASSERT_EQ(Elements.size(), 3u);
ASSERT_EQ(Elements.size(), 1u);

RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);

Elem = Elements[1].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);

Elem = Elements[2].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
auto ValidRootFlags = RootFlags::AllowInputAssemblerInputLayout |
RootFlags::DenyVertexShaderRootAccess |
RootFlags::DenyHullShaderRootAccess |
Expand All @@ -562,6 +552,64 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootFlagsTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidParseEmptyRootFlagsTest) {
using llvm::dxbc::RootFlags;
const llvm::StringLiteral Source = R"cc(
RootFlags(),
)cc";

auto Ctx = createMinimalASTContext();
StringLiteral *Signature = wrapSource(Ctx, Source);

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);

hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);

// Test no diagnostics produced
Consumer->setNoDiag();

ASSERT_FALSE(Parser.parse());

auto Elements = Parser.getElements();
ASSERT_EQ(Elements.size(), 1u);

RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);

ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidParseZeroRootFlagsTest) {
using llvm::dxbc::RootFlags;
const llvm::StringLiteral Source = R"cc(
RootFlags(0),
)cc";

auto Ctx = createMinimalASTContext();
StringLiteral *Signature = wrapSource(Ctx, Source);

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);

hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);

// Test no diagnostics produced
Consumer->setNoDiag();

ASSERT_FALSE(Parser.parse());

auto Elements = Parser.getElements();
ASSERT_EQ(Elements.size(), 1u);

RootElement Elem = Elements[0].getElement();
ASSERT_TRUE(std::holds_alternative<RootFlags>(Elem));
ASSERT_EQ(std::get<RootFlags>(Elem), RootFlags::None);

ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, ValidParseRootDescriptorsTest) {
using llvm::dxbc::RootDescriptorFlags;
const llvm::StringLiteral Source = R"cc(
Expand Down Expand Up @@ -1658,4 +1706,27 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorRangeFlagsValueTest) {
ASSERT_TRUE(Consumer->isSatisfied());
}

TEST_F(ParseHLSLRootSignatureTest, InvalidMultipleRootFlagsTest) {
// This test will check that an error is produced when there are multiple
// root flags provided
const llvm::StringLiteral Source = R"cc(
RootFlags(DENY_VERTEX_SHADER_ROOT_ACCESS),
RootFlags(DENY_PIXEL_SHADER_ROOT_ACCESS)
)cc";

auto Ctx = createMinimalASTContext();
StringLiteral *Signature = wrapSource(Ctx, Source);

TrivialModuleLoader ModLoader;
auto PP = createPP(Source, ModLoader);

hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Signature, *PP);

// Test correct diagnostic produced
Consumer->setExpected(diag::err_hlsl_rootsig_repeat_param);
ASSERT_TRUE(Parser.parse());

ASSERT_TRUE(Consumer->isSatisfied());
}

} // anonymous namespace
4 changes: 4 additions & 0 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ LLVM_ABI bool verifyComparisonFunc(uint32_t ComparisonFunc);
LLVM_ABI bool verifyBorderColor(uint32_t BorderColor);
LLVM_ABI bool verifyLOD(float LOD);

LLVM_ABI bool verifyBoundOffset(uint32_t Offset);
LLVM_ABI bool verifyNoOverflowedOffset(uint64_t Offset);
LLVM_ABI uint64_t computeRangeBound(uint32_t Offset, uint32_t Size);

} // namespace rootsig
} // namespace hlsl
} // namespace llvm
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,22 @@ bool verifyBorderColor(uint32_t BorderColor) {

bool verifyLOD(float LOD) { return !std::isnan(LOD); }

bool verifyBoundOffset(uint32_t Offset) {
return Offset != NumDescriptorsUnbounded;
}

bool verifyNoOverflowedOffset(uint64_t Offset) {
return Offset <= std::numeric_limits<uint32_t>::max();
}

uint64_t computeRangeBound(uint32_t Offset, uint32_t Size) {
assert(0 < Size && "Must be a non-empty range");
if (Size == NumDescriptorsUnbounded)
return NumDescriptorsUnbounded;

return uint64_t(Offset) + uint64_t(Size) - 1;
}

} // namespace rootsig
} // namespace hlsl
} // namespace llvm