Skip to content

Commit 6b4301e

Browse files
committed
[HLSL][Sema] Use hlsl::BindingInfoBuilder instead of RangeInfo. NFC
Clean up some duplicated logic. We had two ways to do the same thing here, and BindingInfoBuilder is more flexible.
1 parent 3f066f5 commit 6b4301e

File tree

5 files changed

+140
-543
lines changed

5 files changed

+140
-543
lines changed

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 140 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "llvm/ADT/StringExtras.h"
4040
#include "llvm/ADT/StringRef.h"
4141
#include "llvm/ADT/Twine.h"
42+
#include "llvm/Frontend/HLSL/HLSLBinding.h"
4243
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
4344
#include "llvm/Support/Casting.h"
4445
#include "llvm/Support/DXILABI.h"
@@ -1083,6 +1084,102 @@ void SemaHLSL::ActOnFinishRootSignatureDecl(
10831084
SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
10841085
}
10851086

1087+
namespace {
1088+
1089+
struct PerVisibilityBindingChecker {
1090+
SemaHLSL *S;
1091+
// We need one builder per `llvm::dxbc::ShaderVisibility` value.
1092+
std::array<llvm::hlsl::BindingInfoBuilder, 8> Builders;
1093+
1094+
struct ElemInfo {
1095+
const hlsl::RootSignatureElement *Elem;
1096+
llvm::dxbc::ShaderVisibility Vis;
1097+
bool Diagnosed;
1098+
};
1099+
llvm::SmallVector<ElemInfo> ElemInfoMap;
1100+
1101+
PerVisibilityBindingChecker(SemaHLSL *S) : S(S) {}
1102+
1103+
void trackBinding(llvm::dxbc::ShaderVisibility Visibility,
1104+
llvm::dxil::ResourceClass RC, uint32_t Space,
1105+
uint32_t LowerBound, uint32_t UpperBound,
1106+
const hlsl::RootSignatureElement *Elem) {
1107+
uint32_t BuilderIndex = llvm::to_underlying(Visibility);
1108+
assert(BuilderIndex < Builders.size() &&
1109+
"Not enough builders for visibility type");
1110+
Builders[BuilderIndex].trackBinding(RC, Space, LowerBound, UpperBound,
1111+
static_cast<const void *>(Elem));
1112+
1113+
static_assert(llvm::to_underlying(llvm::dxbc::ShaderVisibility::All) == 0,
1114+
"'All' visibility must come first");
1115+
if (Visibility == llvm::dxbc::ShaderVisibility::All)
1116+
for (size_t I = 1, E = Builders.size(); I < E; ++I)
1117+
Builders[I].trackBinding(RC, Space, LowerBound, UpperBound,
1118+
static_cast<const void *>(Elem));
1119+
1120+
ElemInfoMap.push_back({Elem, Visibility, false});
1121+
}
1122+
1123+
ElemInfo &getInfo(const hlsl::RootSignatureElement *Elem) {
1124+
auto It = llvm::lower_bound(
1125+
ElemInfoMap, Elem,
1126+
[](const auto &LHS, const auto &RHS) { return LHS.Elem < RHS; });
1127+
assert(It->Elem == Elem && "Element not in map");
1128+
return *It;
1129+
}
1130+
1131+
bool checkOverlap() {
1132+
llvm::sort(ElemInfoMap, [](const auto &LHS, const auto &RHS) {
1133+
return LHS.Elem < RHS.Elem;
1134+
});
1135+
1136+
bool HadOverlap = false;
1137+
1138+
using llvm::hlsl::BindingInfoBuilder;
1139+
auto ReportOverlap = [this, &HadOverlap](
1140+
const BindingInfoBuilder &Builder,
1141+
const BindingInfoBuilder::Binding &Reported) {
1142+
HadOverlap = true;
1143+
1144+
const auto *Elem =
1145+
static_cast<const hlsl::RootSignatureElement *>(Reported.Cookie);
1146+
const BindingInfoBuilder::Binding &Previous =
1147+
Builder.findOverlapping(Reported);
1148+
const auto *PrevElem =
1149+
static_cast<const hlsl::RootSignatureElement *>(Previous.Cookie);
1150+
1151+
ElemInfo &Info = getInfo(Elem);
1152+
// We will have already diagnosed this binding if there's overlap in the
1153+
// "All" visibility as well as any particular visibility.
1154+
if (Info.Diagnosed)
1155+
return;
1156+
Info.Diagnosed = true;
1157+
1158+
ElemInfo &PrevInfo = getInfo(PrevElem);
1159+
llvm::dxbc::ShaderVisibility CommonVis =
1160+
Info.Vis == llvm::dxbc::ShaderVisibility::All ? PrevInfo.Vis
1161+
: Info.Vis;
1162+
1163+
this->S->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1164+
<< llvm::to_underlying(Reported.RC) << Reported.LowerBound
1165+
<< Reported.isUnbounded() << Reported.UpperBound
1166+
<< llvm::to_underlying(Previous.RC) << Previous.LowerBound
1167+
<< Previous.isUnbounded() << Previous.UpperBound << Reported.Space
1168+
<< CommonVis;
1169+
1170+
this->S->Diag(PrevElem->getLocation(),
1171+
diag::note_hlsl_resource_range_here);
1172+
};
1173+
1174+
for (BindingInfoBuilder &Builder : Builders)
1175+
Builder.calculateBindingInfo(ReportOverlap);
1176+
1177+
return HadOverlap;
1178+
}
1179+
};
1180+
1181+
} // end anonymous namespace
1182+
10861183
bool SemaHLSL::handleRootSignatureElements(
10871184
ArrayRef<hlsl::RootSignatureElement> Elements) {
10881185
// Define some common error handling functions
@@ -1171,147 +1268,67 @@ bool SemaHLSL::handleRootSignatureElements(
11711268
}
11721269
}
11731270

1174-
using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
1175-
using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
1176-
using InfoPairT = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
1271+
PerVisibilityBindingChecker BindingChecker(this);
1272+
SmallVector<std::pair<const llvm::hlsl::rootsig::DescriptorTableClause *,
1273+
const hlsl::RootSignatureElement *>>
1274+
UnboundClauses;
11771275

1178-
// 1. Collect RangeInfos
1179-
llvm::SmallVector<InfoPairT> InfoPairs;
11801276
for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
11811277
const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
11821278
if (const auto *Descriptor =
11831279
std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
1184-
RangeInfo Info;
1185-
Info.LowerBound = Descriptor->Reg.Number;
1186-
Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1187-
1188-
Info.Class =
1189-
llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
1190-
Info.Space = Descriptor->Space;
1191-
Info.Visibility = Descriptor->Visibility;
1280+
uint32_t LowerBound(Descriptor->Reg.Number);
1281+
uint32_t UpperBound(LowerBound); // inclusive range
11921282

1193-
InfoPairs.push_back({Info, &RootSigElem});
1283+
BindingChecker.trackBinding(
1284+
Descriptor->Visibility,
1285+
static_cast<llvm::dxil::ResourceClass>(Descriptor->Type),
1286+
Descriptor->Space, LowerBound, UpperBound, &RootSigElem);
11941287
} else if (const auto *Constants =
11951288
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
1196-
RangeInfo Info;
1197-
Info.LowerBound = Constants->Reg.Number;
1198-
Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1289+
uint32_t LowerBound(Constants->Reg.Number);
1290+
uint32_t UpperBound(LowerBound); // inclusive range
11991291

1200-
Info.Class = llvm::dxil::ResourceClass::CBuffer;
1201-
Info.Space = Constants->Space;
1202-
Info.Visibility = Constants->Visibility;
1203-
1204-
InfoPairs.push_back({Info, &RootSigElem});
1292+
BindingChecker.trackBinding(
1293+
Constants->Visibility, llvm::dxil::ResourceClass::CBuffer,
1294+
Constants->Space, LowerBound, UpperBound, &RootSigElem);
12051295
} else if (const auto *Sampler =
12061296
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
1207-
RangeInfo Info;
1208-
Info.LowerBound = Sampler->Reg.Number;
1209-
Info.UpperBound = Info.LowerBound; // use inclusive ranges []
1210-
1211-
Info.Class = llvm::dxil::ResourceClass::Sampler;
1212-
Info.Space = Sampler->Space;
1213-
Info.Visibility = Sampler->Visibility;
1297+
uint32_t LowerBound(Sampler->Reg.Number);
1298+
uint32_t UpperBound(LowerBound); // inclusive range
12141299

1215-
InfoPairs.push_back({Info, &RootSigElem});
1300+
BindingChecker.trackBinding(
1301+
Sampler->Visibility, llvm::dxil::ResourceClass::Sampler,
1302+
Sampler->Space, LowerBound, UpperBound, &RootSigElem);
12161303
} else if (const auto *Clause =
12171304
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
12181305
&Elem)) {
1219-
RangeInfo Info;
1220-
Info.LowerBound = Clause->Reg.Number;
1221-
// Relevant error will have already been reported above and needs to be
1222-
// fixed before we can conduct range analysis, so shortcut error return
1223-
if (Clause->NumDescriptors == 0)
1224-
return true;
1225-
Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
1226-
? RangeInfo::Unbounded
1227-
: Info.LowerBound + Clause->NumDescriptors -
1228-
1; // use inclusive ranges []
1229-
1230-
Info.Class = Clause->Type;
1231-
Info.Space = Clause->Space;
1232-
1233-
// Note: Clause does not hold the visibility this will need to
1234-
InfoPairs.push_back({Info, &RootSigElem});
1306+
// We'll process these once we see the table element.
1307+
UnboundClauses.emplace_back(Clause, &RootSigElem);
12351308
} else if (const auto *Table =
12361309
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
1237-
// Table holds the Visibility of all owned Clauses in Table, so iterate
1238-
// owned Clauses and update their corresponding RangeInfo
1239-
assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
1240-
// The last Table->NumClauses elements of Infos are the owned Clauses
1241-
// generated RangeInfo
1242-
auto TableInfos =
1243-
MutableArrayRef<InfoPairT>(InfoPairs).take_back(Table->NumClauses);
1244-
for (InfoPairT &Pair : TableInfos)
1245-
Pair.first.Visibility = Table->Visibility;
1246-
}
1247-
}
1248-
1249-
// 2. Sort with the RangeInfo <operator to prepare it for findOverlapping
1250-
llvm::sort(InfoPairs,
1251-
[](InfoPairT A, InfoPairT B) { return A.first < B.first; });
1252-
1253-
llvm::SmallVector<RangeInfo> Infos;
1254-
for (const InfoPairT &Pair : InfoPairs)
1255-
Infos.push_back(Pair.first);
1256-
1257-
// Helpers to report diagnostics
1258-
uint32_t DuplicateCounter = 0;
1259-
using ElemPair = std::pair<const hlsl::RootSignatureElement *,
1260-
const hlsl::RootSignatureElement *>;
1261-
auto GetElemPair = [&Infos, &InfoPairs, &DuplicateCounter](
1262-
OverlappingRanges Overlap) -> ElemPair {
1263-
// Given we sorted the InfoPairs (and by implication) Infos, and,
1264-
// that Overlap.B is the item retrieved from the ResourceRange. Then it is
1265-
// guarenteed that Overlap.B <= Overlap.A.
1266-
//
1267-
// So we will find Overlap.B first and then continue to find Overlap.A
1268-
// after
1269-
auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
1270-
auto DistB = std::distance(Infos.begin(), InfoB);
1271-
auto PairB = InfoPairs.begin();
1272-
std::advance(PairB, DistB);
1273-
1274-
auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
1275-
// Similarily, from the property that we have sorted the RangeInfos,
1276-
// all duplicates will be processed one after the other. So
1277-
// DuplicateCounter can be re-used for each set of duplicates we
1278-
// encounter as we handle incoming errors
1279-
DuplicateCounter = InfoA == InfoB ? DuplicateCounter + 1 : 0;
1280-
auto DistA = std::distance(InfoB, InfoA) + DuplicateCounter;
1281-
auto PairA = PairB;
1282-
std::advance(PairA, DistA);
1283-
1284-
return {PairA->second, PairB->second};
1285-
};
1286-
1287-
auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
1288-
auto Pair = GetElemPair(Overlap);
1289-
const RangeInfo *Info = Overlap.A;
1290-
const hlsl::RootSignatureElement *Elem = Pair.first;
1291-
const RangeInfo *OInfo = Overlap.B;
1292-
1293-
auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
1294-
? OInfo->Visibility
1295-
: Info->Visibility;
1296-
this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
1297-
<< llvm::to_underlying(Info->Class) << Info->LowerBound
1298-
<< /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
1299-
<< Info->UpperBound << llvm::to_underlying(OInfo->Class)
1300-
<< OInfo->LowerBound
1301-
<< /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
1302-
<< OInfo->UpperBound << Info->Space << CommonVis;
1303-
1304-
const hlsl::RootSignatureElement *OElem = Pair.second;
1305-
this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
1306-
};
1307-
1308-
// 3. Invoke find overlapping ranges
1309-
llvm::SmallVector<OverlappingRanges> Overlaps =
1310-
llvm::hlsl::rootsig::findOverlappingRanges(Infos);
1311-
for (OverlappingRanges Overlap : Overlaps)
1312-
ReportOverlap(Overlap);
1310+
assert(UnboundClauses.size() == Table->NumClauses &&
1311+
"Wrong number of clauses in table?");
1312+
for (const auto &[Clause, ClauseElem] : UnboundClauses) {
1313+
uint32_t LowerBound(Clause->Reg.Number);
1314+
// Relevant error will have already been reported above and needs to be
1315+
// fixed before we can conduct range analysis, so shortcut error return
1316+
if (Clause->NumDescriptors == 0)
1317+
return true;
1318+
uint32_t UpperBound = Clause->NumDescriptors == ~0u
1319+
? ~0u
1320+
: LowerBound + Clause->NumDescriptors - 1;
1321+
1322+
BindingChecker.trackBinding(
1323+
Table->Visibility,
1324+
static_cast<llvm::dxil::ResourceClass>(Clause->Type), Clause->Space,
1325+
LowerBound, UpperBound, ClauseElem);
1326+
}
1327+
UnboundClauses.clear();
1328+
}
1329+
}
13131330

1314-
return Overlaps.size() != 0;
1331+
return BindingChecker.checkOverlap();
13151332
}
13161333

13171334
void SemaHLSL::handleRootSignatureAttr(Decl *D, const ParsedAttr &AL) {

0 commit comments

Comments
 (0)