Skip to content

Commit a7bf10f

Browse files
committed
proto: using std::pair mapping with sorting outside of function
1 parent c3305b6 commit a7bf10f

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,9 +1085,10 @@ bool SemaHLSL::handleRootSignatureElements(
10851085
ArrayRef<hlsl::RootSignatureElement> Elements) {
10861086
using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
10871087
using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;
1088+
using InfoPair = std::pair<RangeInfo, const hlsl::RootSignatureElement *>;
10881089

10891090
// 1. Collect RangeInfos
1090-
llvm::SmallVector<RangeInfo> Infos;
1091+
llvm::SmallVector<InfoPair> InfoPairs;
10911092
for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
10921093
const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
10931094
if (const auto *Descriptor =
@@ -1101,8 +1102,7 @@ bool SemaHLSL::handleRootSignatureElements(
11011102
Info.Space = Descriptor->Space;
11021103
Info.Visibility = Descriptor->Visibility;
11031104

1104-
Info.Cookie = static_cast<void *>(&RootSigElem);
1105-
Infos.push_back(Info);
1105+
InfoPairs.push_back({Info, &RootSigElem});
11061106
} else if (const auto *Constants =
11071107
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
11081108
RangeInfo Info;
@@ -1113,8 +1113,7 @@ bool SemaHLSL::handleRootSignatureElements(
11131113
Info.Space = Constants->Space;
11141114
Info.Visibility = Constants->Visibility;
11151115

1116-
Info.Cookie = static_cast<void *>(&RootSigElem);
1117-
Infos.push_back(Info);
1116+
InfoPairs.push_back({Info, &RootSigElem});
11181117
} else if (const auto *Sampler =
11191118
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
11201119
RangeInfo Info;
@@ -1125,8 +1124,7 @@ bool SemaHLSL::handleRootSignatureElements(
11251124
Info.Space = Sampler->Space;
11261125
Info.Visibility = Sampler->Visibility;
11271126

1128-
Info.Cookie = static_cast<void *>(&RootSigElem);
1129-
Infos.push_back(Info);
1127+
InfoPairs.push_back({Info, &RootSigElem});
11301128
} else if (const auto *Clause =
11311129
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
11321130
&Elem)) {
@@ -1142,30 +1140,59 @@ bool SemaHLSL::handleRootSignatureElements(
11421140
Info.Space = Clause->Space;
11431141

11441142
// Note: Clause does not hold the visibility this will need to
1145-
Info.Cookie = static_cast<void *>(&RootSigElem);
1146-
Infos.push_back(Info);
1143+
InfoPairs.push_back({Info, &RootSigElem});
11471144
} else if (const auto *Table =
11481145
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
11491146
// Table holds the Visibility of all owned Clauses in Table, so iterate
11501147
// owned Clauses and update their corresponding RangeInfo
1151-
assert(Table->NumClauses <= Infos.size() && "RootElement");
1148+
assert(Table->NumClauses <= InfoPairs.size() && "RootElement");
11521149
// The last Table->NumClauses elements of Infos are the owned Clauses
11531150
// generated RangeInfo
11541151
auto TableInfos =
1155-
MutableArrayRef<RangeInfo>(Infos).take_back(Table->NumClauses);
1156-
for (RangeInfo &Info : TableInfos)
1157-
Info.Visibility = Table->Visibility;
1152+
MutableArrayRef<InfoPair>(InfoPairs).take_back(Table->NumClauses);
1153+
for (InfoPair &Pair : TableInfos)
1154+
Pair.first.Visibility = Table->Visibility;
11581155
}
11591156
}
11601157

1161-
// Helper to report diagnostics
1162-
auto ReportOverlap = [this](OverlappingRanges Overlap) {
1158+
// Sort as specified
1159+
auto ComparePairs = [](InfoPair A, InfoPair B) { return A.first < B.first; };
1160+
1161+
std::sort(InfoPairs.begin(), InfoPairs.end(), ComparePairs);
1162+
1163+
llvm::SmallVector<RangeInfo> Infos;
1164+
for (const InfoPair &Pair : InfoPairs)
1165+
Infos.push_back(Pair.first);
1166+
1167+
// Helpers to report diagnostics
1168+
using ElemPair = std::pair<const hlsl::RootSignatureElement *,
1169+
const hlsl::RootSignatureElement *>;
1170+
auto GetElemPair = [&Infos, &InfoPairs](
1171+
OverlappingRanges Overlap) -> ElemPair {
1172+
auto InfoB = std::lower_bound(Infos.begin(), Infos.end(), *Overlap.B);
1173+
auto DistB = std::distance(Infos.begin(), InfoB);
1174+
auto PairB = InfoPairs.begin();
1175+
std::advance(PairB, DistB);
1176+
1177+
auto InfoA = std::lower_bound(InfoB, Infos.end(), *Overlap.A);
1178+
if (InfoA == InfoB)
1179+
InfoA++;
1180+
auto DistA = std::distance(InfoB, InfoA);
1181+
auto PairA = PairB;
1182+
std::advance(PairA, DistA);
1183+
1184+
return {PairA->second, PairB->second};
1185+
};
1186+
1187+
auto ReportOverlap = [this, &GetElemPair](OverlappingRanges Overlap) {
1188+
auto Pair = GetElemPair(Overlap);
11631189
const RangeInfo *Info = Overlap.A;
1190+
const hlsl::RootSignatureElement *Elem = Pair.first;
11641191
const RangeInfo *OInfo = Overlap.B;
1192+
11651193
auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
11661194
? OInfo->Visibility
11671195
: Info->Visibility;
1168-
auto Elem = static_cast<const hlsl::RootSignatureElement *>(Info->Cookie);
11691196
this->Diag(Elem->getLocation(), diag::err_hlsl_resource_range_overlap)
11701197
<< llvm::to_underlying(Info->Class) << Info->LowerBound
11711198
<< /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
@@ -1174,7 +1201,7 @@ bool SemaHLSL::handleRootSignatureElements(
11741201
<< /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
11751202
<< OInfo->UpperBound << Info->Space << CommonVis;
11761203

1177-
auto OElem = static_cast<const hlsl::RootSignatureElement *>(OInfo->Cookie);
1204+
const hlsl::RootSignatureElement *OElem = Pair.second;
11781205
this->Diag(OElem->getLocation(), diag::note_hlsl_resource_range_here);
11791206
};
11801207

llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,17 @@ struct RangeInfo {
5252
uint32_t Space;
5353
llvm::dxbc::ShaderVisibility Visibility;
5454

55-
// Retain information for diagnostic reporting
56-
void *Cookie;
55+
bool operator==(const RangeInfo &RHS) {
56+
return std::tie(LowerBound, UpperBound, Class, Space, Visibility) ==
57+
std::tie(RHS.LowerBound, RHS.UpperBound, RHS.Class, RHS.Space,
58+
RHS.Visibility);
59+
}
60+
61+
bool operator<(const RangeInfo &RHS) {
62+
return std::tie(Class, Space, LowerBound, UpperBound, Visibility) <
63+
std::tie(RHS.Class, RHS.Space, RHS.LowerBound, RHS.UpperBound,
64+
RHS.Visibility);
65+
}
5766
};
5867

5968
class ResourceRange {
@@ -137,7 +146,7 @@ struct OverlappingRanges {
137146
/// ResourceRange
138147
/// B: Check for overlap with any overlapping Visibility ResourceRange
139148
llvm::SmallVector<OverlappingRanges>
140-
findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos);
149+
findOverlappingRanges(ArrayRef<RangeInfo> Infos);
141150

142151
} // namespace rootsig
143152
} // namespace hlsl

llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,11 @@ std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
244244
}
245245

246246
llvm::SmallVector<OverlappingRanges>
247-
findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos) {
247+
findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
248248
// 1. The user has provided the corresponding range information
249249
llvm::SmallVector<OverlappingRanges> Overlaps;
250250
using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;
251251

252-
// 2. Sort the RangeInfo's by their GroupT to form groupings
253-
std::sort(Infos.begin(), Infos.end(), [](RangeInfo A, RangeInfo B) {
254-
return std::tie(A.Class, A.Space) < std::tie(B.Class, B.Space);
255-
});
256-
257252
// 3. First we will init our state to track:
258253
if (Infos.size() == 0)
259254
return Overlaps; // No ranges to overlap

0 commit comments

Comments
 (0)