Skip to content

[HLSL][Sema] Use hlsl::BindingInfoBuilder instead of RangeInfo. NFC #150634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 140 additions & 123 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Frontend/HLSL/HLSLBinding.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DXILABI.h"
Expand Down Expand Up @@ -1083,6 +1084,102 @@ void SemaHLSL::ActOnFinishRootSignatureDecl(
SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
}

namespace {

struct PerVisibilityBindingChecker {
SemaHLSL *S;
// We need one builder per `llvm::dxbc::ShaderVisibility` value.
std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;

struct ElemInfo {
const hlsl::RootSignatureElement *Elem;
llvm::dxbc::ShaderVisibility Vis;
bool Diagnosed;
};
llvm::SmallVector<ElemInfo> ElemInfoMap;

PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}

void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
llvm::dxil::ResourceClass RC, uint32_t Space,
uint32_t LowerBound, uint32_t UpperBound,
const hlsl::RootSignatureElement *Elem) {
uint32_t BuilderIndex = llvm::to_underlying(Visibility);
assert(BuilderIndex < Builders.size() &&
"Not enough builders for visibility type");
Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
static_cast<const void *>(Elem));

static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
"'All' visibility must come first");
if (Visibility == llvm::dxbc::ShaderVisibility::All)
for (size_t I = 1, E = Builders.size(); I < E; ++I)
Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
static_cast<const void *>(Elem));

ElemInfoMap.push_back({Elem, Visibility, false});
}

ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
auto It = llvm::lower_bound(
ElemInfoMap, Elem,
[](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
assert(It->Elem == Elem && "Element not in map");
return *It;
}

bool checkOverlap() {
llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
return LHS.Elem < RHS.Elem;
});

bool HadOverlap = false;

using llvm::hlsl::BindingInfoBuilder;
auto ReportOverlap = [this, &HadOverlap](
const BindingInfoBuilder &Builder,
const BindingInfoBuilder::Binding &Reported) {
HadOverlap = true;

const auto *Elem =
static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
const BindingInfoBuilder::Binding &Previous =
Builder.findOverlapping(Reported);
const auto *PrevElem =
static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);

ElemInfo &Info = getInfo(Elem);
// We will have already diagnosed this binding if there's overlap in the
// "All" visibility as well as any particular visibility.
if (Info.Diagnosed)
return;
Info.Diagnosed = true;

ElemInfo &PrevInfo = getInfo(PrevElem);
llvm::dxbc::ShaderVisibility CommonVis =
Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
: Info.Vis;

this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
<< llvm::to_underlying(Reported.RC) << Reported.LowerBound
<< Reported.isUnbounded() << Reported.UpperBound
<< llvm::to_underlying(Previous.RC) << Previous.LowerBound
<< Previous.isUnbounded() << Previous.UpperBound << Reported.Space
<< CommonVis;

this->S->Diag(PrevElem->getLocation(),
diag::note_hlsl_resource_range_here);
};

for (BindingInfoBuilder &Builder : Builders)
Builder.calculateBindingInfo(ReportOverlap);

return HadOverlap;
}
};

} // end anonymous namespace

bool SemaHLSL::handleRootSignatureElements(
ArrayRef<hlsl::RootSignatureElement> Elements) {
// Define some common error handling functions
Expand Down Expand Up @@ -1171,147 +1268,67 @@ bool SemaHLSL::handleRootSignatureElements(
}
}

using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
PerVisibilityBindingChecker BindingChecker(this);
SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
const hlsl::RootSignatureElement *>>
UnboundClauses;

// 1. Collect RangeInfos
llvm::SmallVector<InfoPairT> InfoPairs;
for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
if (const auto *Descriptor =
std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Descriptor->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class =
llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
Info.Space = Descriptor->Space;
Info.Visibility = Descriptor->Visibility;
uint32_t LowerBound(Descriptor->Reg.Number);
uint32_t UpperBound(LowerBound); // inclusive range

InfoPairs.push_back({Info, &RootSigElem});
BindingChecker.trackBinding(
Descriptor->Visibility,
static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Constants->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []
uint32_t LowerBound(Constants->Reg.Number);
uint32_t UpperBound(LowerBound); // inclusive range

Info.Class = llvm::dxil::ResourceClass::CBuffer;
Info.Space = Constants->Space;
Info.Visibility = Constants->Visibility;

InfoPairs.push_back({Info, &RootSigElem});
BindingChecker.trackBinding(
Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
Constants->Space, LowerBound, UpperBound, &RootSigElem);
} else if (const auto *Sampler =
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Sampler->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class = llvm::dxil::ResourceClass::Sampler;
Info.Space = Sampler->Space;
Info.Visibility = Sampler->Visibility;
uint32_t LowerBound(Sampler->Reg.Number);
uint32_t UpperBound(LowerBound); // inclusive range

InfoPairs.push_back({Info, &RootSigElem});
BindingChecker.trackBinding(
Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
Sampler->Space, LowerBound, UpperBound, &RootSigElem);
} else if (const auto *Clause =
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
&Elem)) {
RangeInfo Info;
Info.LowerBound = Clause->Reg.Number;
// Relevant error will have already been reported above and needs to be
// fixed before we can conduct range analysis, so shortcut error return
if (Clause->NumDescriptors == 0)
return true;
Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
? RangeInfo::Unbounded
: Info.LowerBound + Clause->NumDescriptors -
1; // use inclusive ranges []

Info.Class = Clause->Type;
Info.Space = Clause->Space;

// Note: Clause does not hold the visibility this will need to
InfoPairs.push_back({Info, &RootSigElem});
// We'll process these once we see the table element.
UnboundClauses.emplace_back(Clause, &RootSigElem);
} else if (const auto *Table =
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
// Table holds the Visibility of all owned Clauses in Table, so iterate
// owned Clauses and update their corresponding RangeInfo
assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
// The last Table->NumClauses elements of Infos are the owned Clauses
// generated RangeInfo
auto TableInfos =
MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses);
for (InfoPairT &Pair : TableInfos)
Pair.first.Visibility = Table->Visibility;
}
}

// 2. Sort with the RangeInfo <operator to prepare it for findOverlapping
llvm::sort(InfoPairs,
[](InfoPairT A, InfoPairT B) { return A.first < B.first; });

llvm::SmallVector<RangeInfo> Infos;
for (const InfoPairT &Pair : InfoPairs)
Infos.push_back(Pair.first);

// Helpers to report diagnostics
uint32_t DuplicateCounter = 0;
using ElemPair = std::pair<const hlsl::RootSignatureElement *,
const hlsl::RootSignatureElement *>;
auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
OverlappingRanges Overlap) -> ElemPair {
// Given we sorted the InfoPairs (and by implication) Infos, and,
// that Overlap.B is the item retrieved from the ResourceRange. Then it is
// guarenteed that Overlap.B <= Overlap.A.
//
// So we will find Overlap.B first and then continue to find Overlap.A
// after
auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
auto DistB = std::distance(Infos.begin(), InfoB);
auto PairB = InfoPairs.begin();
std::advance(PairB, DistB);

auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
// Similarily, from the property that we have sorted the RangeInfos,
// all duplicates will be processed one after the other. So
// DuplicateCounter can be re-used for each set of duplicates we
// encounter as we handle incoming errors
DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter;
auto PairA = PairB;
std::advance(PairA, DistA);

return {PairA->second, PairB->second};
};

auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
auto Pair = GetElemPair(Overlap);
const RangeInfo *Info = Overlap.A;
const hlsl::RootSignatureElement *Elem = Pair.first;
const RangeInfo *OInfo = Overlap.B;

auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
? OInfo->Visibility
: Info->Visibility;
this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
<< llvm::to_underlying(Info->Class) << Info->LowerBound
<< /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
<< Info->UpperBound << llvm::to_underlying(OInfo->Class)
<< OInfo->LowerBound
<< /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
<< OInfo->UpperBound << Info->Space << CommonVis;

const hlsl::RootSignatureElement *OElem = Pair.second;
this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
};

// 3. Invoke find overlapping ranges
llvm::SmallVector<OverlappingRanges> Overlaps =
llvm::hlsl::rootsig::findOverlappingRanges(Infos);
for (OverlappingRanges Overlap : Overlaps)
ReportOverlap(Overlap);
assert(UnboundClauses.size() == Table->NumClauses &&
"Number of unbound elements must match the number of clauses");
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
uint32_t LowerBound(Clause->Reg.Number);
// Relevant error will have already been reported above and needs to be
// fixed before we can conduct range analysis, so shortcut error return
if (Clause->NumDescriptors == 0)
return true;
uint32_t UpperBound = Clause->NumDescriptors == ~0u
? ~0u
: LowerBound + Clause->NumDescriptors - 1;

BindingChecker.trackBinding(
Table->Visibility,
static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
LowerBound, UpperBound, ClauseElem);
}
UnboundClauses.clear();
}
}

return Overlaps.size() != 0;
return BindingChecker.checkOverlap();
}

void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {
Expand Down
Loading