
Commit 8261528

jrtc27 authored and resistor committed
[SROA] Handle more alignment corner cases rather than asserting
The existing downstream alignment handling stopped merging in unsplittable slices if they would be unaligned. However, in this case we're already covering at least part of such slices with the current partition, and that partition contains unsplittable slices, so we cannot just stop early. The result was that, if we ever hit that case, the next call to advance would fail the assertion that BeginOffset didn't go backwards.

The current single-pass implementation from upstream is complicated to reason about, mixing correctness properties (handling of unsplittable slices and, in our extended code, alignment) with heuristics on where to split within those constraints. It is also done as a single pass which, whilst workable without alignment constraints, is not in general possible once alignment is required (at least not unless you want to introduce the concept of head padding).

This new implementation introduces AllocaSplitCandidates, which pre-scans the slices to determine the set of possible split points, and AllocaSplitRequester, which takes a heuristic-based request from AllocaSlices::partition_iterator and returns a legal split point for it to use. It also allows multiple queries before committing to one, so the split point can be advanced and walked back as more slices are scanned.

The intent is that, despite the new implementation, the returned partitions are unchanged for all cases that were previously handled correctly. In particular, since slices only ever have an alignment other than 1 for CHERI, non-CHERI should not see any changes.
1 parent 1f87ff7 commit 8261528

File tree

2 files changed

+338
-45
lines changed


llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 232 additions & 45 deletions
@@ -24,6 +24,7 @@
 
 #include "llvm/Transforms/Scalar/SROA.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/AddressRanges.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -857,6 +858,184 @@ class Partition {
   ArrayRef<Slice *> splitSliceTails() const { return SplitTails; }
 };
 
+class AllocaSplitCandidates {
+  friend class AllocaSplitRequester;
+
+  AddressRanges CandidateSplits;
+
+  // Some split endpoints are required as a simplification to ensure slices are
+  // kept aligned due to the current pre-computation implementation.
+  SmallVector<uint64_t, 8> RequiredSplits;
+
+public:
+  AllocaSplitCandidates(AllocaSlices::iterator SI, AllocaSlices::iterator SE) {
+    uint64_t Size = 0;
+    AddressRanges ExcludeRanges;
+
+    // Compute the union of all unsplittable slices, and the max of all slices.
+    for (const auto &S : make_range(SI, SE)) {
+      Size = std::max(S.endOffset(), Size);
+      assert(S.endOffset() - S.beginOffset() > 0);
+      if (S.isSplittable() || S.endOffset() - S.beginOffset() < 2)
+        continue;
+
+      ExcludeRanges.insert({S.beginOffset() + 1, S.endOffset()});
+      LLVM_DEBUG(dbgs() << "Excluding split range [" << S.beginOffset() + 1
+                        << ", " << S.endOffset() - 1
+                        << "] due to unsplittable slice\n");
+    }
+
+    // Exclude ranges that would require introducing padding. That is, ensure
+    // that, for any candidate, it maintains the alignment of any unsplittable
+    // slices (splittable slices never have an alignment requirement) between
+    // it and the next candidate.
+    auto RSI = std::make_reverse_iterator(SE);
+    auto RSE = std::make_reverse_iterator(SI);
+    uint64_t Gap = ExcludeRanges.size();
+    while (RSI != RSE) {
+      const auto *S = &*RSI++;
+      Align AlignReq = S->minAlignment();
+      if (AlignReq == 1)
+        continue;
+
+      // This slice has an alignment requirement. Keep walking back through the
+      // excluded ranges until we find a gap that can satisfy that alignment,
+      // along with the alignment for all other slices we encounter along the
+      // way.
+      uint64_t ExcludeEnd = S->beginOffset();
+      uint64_t CurStart = alignDown(ExcludeEnd, AlignReq.value());
+      bool FindGap = true;
+      while (FindGap) {
+        uint64_t GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end();
+        uint64_t GapLast =
+            Gap == ExcludeRanges.size() ? Size : ExcludeRanges[Gap].start() - 1;
+        uint64_t AlignedGapBegin = alignTo(GapBegin, AlignReq);
+        uint64_t AlignedGapLast = alignDown(GapLast, AlignReq.value());
+        while (Gap > 0) {
+          if (AlignedGapBegin <= AlignedGapLast && AlignedGapBegin <= CurStart)
+            break;
+
+          --Gap;
+          GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end();
+          GapLast = ExcludeRanges[Gap].start() - 1;
+          AlignedGapBegin = alignTo(GapBegin, AlignReq);
+          AlignedGapLast = alignDown(GapLast, AlignReq.value());
+        }
+
+        // This should always be true; either Gap > 0 and so we already checked
+        // this above, or Gap == 0 and so the offset is 0, which trivially
+        // satisfies this.
+        assert(AlignedGapBegin <= AlignedGapLast &&
+               AlignedGapBegin <= CurStart &&
+               "Could not find aligned split point!");
+        CurStart = std::min(AlignedGapLast, CurStart);
+        FindGap = false;
+
+        // Scan through all the slices that will be included between this split
+        // point and the previous one and check if any invalidate our choice of
+        // gap.
+        while (RSI != RSE && RSI->beginOffset() >= CurStart) {
+          S = &*RSI++;
+          if (S->minAlignment() <= AlignReq)
+            continue;
+
+          AlignReq = S->minAlignment();
+          AlignedGapBegin = alignTo(GapBegin, AlignReq);
+          AlignedGapLast = alignDown(GapLast, AlignReq.value());
+          CurStart = alignDown(CurStart, AlignReq.value());
+          if (AlignedGapBegin <= AlignedGapLast && AlignedGapBegin <= CurStart)
+            continue;
+
+          FindGap = true;
+        }
+      }
+
+      // Any split in (CurStart, ExcludeEnd] would not be satisfiable without
+      // padding.
+      if (CurStart < ExcludeEnd) {
+        ExcludeRanges.insert({CurStart + 1, ExcludeEnd + 1});
+        LLVM_DEBUG(dbgs() << "Excluding split range [" << CurStart + 1 << ", "
+                          << ExcludeEnd << "] due to aligned slices\n");
+      }
+
+      // We assume that we don't need to consider the alignment of any slices
+      // from this offset onwards on the next iteration (so that unaligned head
+      // slices can be split off regardless of what's in this tail), so enforce
+      // this by requiring a split at this point (which minimises the size of
+      // the partition that will include these slices).
+      if (CurStart > 0)
+        RequiredSplits.push_back(CurStart);
+    }
+
+    for (size_t Gap = 0; Gap <= ExcludeRanges.size(); ++Gap) {
+      uint64_t GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end();
+      uint64_t GapLast =
+          Gap == ExcludeRanges.size() ? Size : ExcludeRanges[Gap].start() - 1;
+      CandidateSplits.insert({GapBegin, GapLast + 1});
+      LLVM_DEBUG(dbgs() << "Candidate split range [" << GapBegin << ", "
+                        << GapLast << "]\n");
+    }
+
+    // Keep the required splits in ascending order.
+    std::reverse(RequiredSplits.begin(), RequiredSplits.end());
+
+    LLVM_DEBUG({
+      for (uint64_t Split : RequiredSplits)
+        dbgs() << "Required split " << Split << "\n";
+    });
+  }
+};
+
+class AllocaSplitRequester {
+  const AllocaSplitCandidates C;
+
+  struct {
+    size_t CurCandidates = 0;
+    size_t NextRequired = 0;
+#ifndef NDEBUG
+    uint64_t MinNextReq = 0;
+#endif
+  } Committed, Staged;
+
+public:
+  AllocaSplitRequester(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
+      : C(SI, SE) {}
+
+  uint64_t requestNext(uint64_t Req, bool RoundUp) {
+    Staged = Committed;
+
+    assert(Req < (C.CandidateSplits.end() - 1)->end());
+    assert(Req >= Staged.MinNextReq);
+
+    if (Staged.NextRequired < C.RequiredSplits.size())
+      Req = std::min(Req, C.RequiredSplits[Staged.NextRequired]);
+
+    while (Req >= C.CandidateSplits[Staged.CurCandidates].end())
+      ++Staged.CurCandidates;
+
+    if (Req < C.CandidateSplits[Staged.CurCandidates].start()) {
+      if (RoundUp)
+        Req = C.CandidateSplits[Staged.CurCandidates].start();
+      else
+        Req = C.CandidateSplits[Staged.CurCandidates - 1].end() - 1;
+    }
+
+    if (Staged.NextRequired < C.RequiredSplits.size()) {
+      if (Req == C.RequiredSplits[Staged.NextRequired])
+        ++Staged.NextRequired;
+      else
+        assert(Req < C.RequiredSplits[Staged.NextRequired]);
+    }
+
+#ifndef NDEBUG
+    Staged.MinNextReq = Req + 1;
+#endif
+    return Req;
+  }
+
+  void commitNext() { Committed = Staged; }
+};
+
 } // end anonymous namespace
 
 /// An iterator over partitions of the alloca's slices.
@@ -880,14 +1059,18 @@ class AllocaSlices::partition_iterator
   /// We need to keep the end of the slices to know when to stop.
   AllocaSlices::iterator SE;
 
+  /// Calculates the legal split points for us to enact with whatever policy we
+  /// like.
+  AllocaSplitRequester R;
+
   /// We also need to keep track of the maximum split end offset seen.
   /// FIXME: Do we really?
   uint64_t MaxSplitSliceEndOffset = 0;
 
   /// Sets the partition to be empty at given iterator, and sets the
   /// end iterator.
   partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
-      : P(SI), SE(SE) {
+      : P(SI), SE(SE), R(SI, SE) {
     // If not already at the end, advance our state to form the initial
     // partition.
     if (SI != SE)
@@ -957,64 +1140,68 @@ class AllocaSlices::partition_iterator
 
       // If the we have split slices and the next slice is after a gap and is
       // not splittable immediately form an empty partition for the split
-      // slices up until the next slice begins.
+      // slices up until as close to the next slice as we can get, if possible.
       if (!P.SplitTails.empty() && P.SI->beginOffset() != P.EndOffset &&
          !P.SI->isSplittable()) {
-        P.BeginOffset = P.EndOffset;
-        P.EndOffset = P.SI->beginOffset();
-        return;
+        uint64_t NextOffset =
+            R.requestNext(P.SI->beginOffset(), /*RoundUp=*/false);
+        if (NextOffset > P.EndOffset) {
+          P.BeginOffset = P.EndOffset;
+          P.EndOffset = NextOffset;
+          R.commitNext();
+          return;
+        }
+        assert(NextOffset == P.EndOffset && "requestNext went backwards!");
       }
     }
 
     // OK, we need to consume new slices. Set the end offset based on the
     // current slice, and step SJ past it. The beginning offset of the
-    // partition is the beginning offset of the next slice unless we have
-    // pre-existing split slices that are continuing, in which case we begin
-    // at the prior end offset.
-    P.BeginOffset = P.SplitTails.empty() ? P.SI->beginOffset() : P.EndOffset;
-    P.EndOffset = P.SI->endOffset();
+    // partition is as close to the beginning offset of the next slice as we
+    // can get, unless we have pre-existing split slices that are continuing,
+    // in which case we begin at the prior end offset.
+    if (P.SplitTails.empty() && P.SI->beginOffset() > P.EndOffset) {
+      P.BeginOffset = R.requestNext(P.SI->beginOffset(), /*RoundUp=*/false);
+      R.commitNext();
+    } else
+      P.BeginOffset = P.EndOffset;
+    P.EndOffset = R.requestNext(P.SI->endOffset(), /*RoundUp=*/true);
+    bool Splittable = P.SI->isSplittable();
     ++P.SJ;
 
-    // There are two strategies to form a partition based on whether the
-    // partition starts with an unsplittable slice or a splittable slice.
-    if (!P.SI->isSplittable()) {
-      // When we're forming an unsplittable region, it must always start at
-      // the first slice and will extend through its end.
-      assert(P.BeginOffset == P.SI->beginOffset());
-
-      // Form a partition including all of the overlapping slices with this
-      // unsplittable slice.
-      while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
-             P.SJ->isAligned(P.BeginOffset)) {
-        if (!P.SJ->isSplittable())
-          P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
-        ++P.SJ;
+    // Collect all the overlapping slices and grow the partition if possible.
+    // If we encounter an unsplittable slice, try to stop before it, otherwise
+    // stop as soon after it as possible.
+    Align AlignReq = P.SI->minAlignment();
+    while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
+      assert((P.SJ->isSplittable() || P.SJ->endOffset() <= P.EndOffset) &&
+             "requestNext tried to split unsplittable slice!");
+      if (Splittable) {
+        if (!P.SJ->isSplittable()) {
+          Splittable = false;
+          if (P.SJ->beginOffset() > P.BeginOffset) {
+            P.EndOffset = R.requestNext(P.SJ->beginOffset(), /*RoundUp=*/false);
+            if (P.EndOffset > P.BeginOffset)
+              break;
+          }
+        }
+        if (!P.SJ->isSplittable() || P.SJ->endOffset() > P.EndOffset)
+          P.EndOffset = R.requestNext(P.SJ->endOffset(), /*RoundUp=*/true);
       }
-
-      // We have a partition across a set of overlapping unsplittable
-      // partitions.
-      return;
-    }
-
-    // If we're starting with a splittable slice, then we need to form
-    // a synthetic partition spanning it and any other overlapping splittable
-    // splices.
-    assert(P.SI->isSplittable() && "Forming a splittable partition!");
-
-    // Collect all of the overlapping splittable slices.
-    while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
-           P.SJ->isSplittable()) {
-      P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
+      if (P.SJ->minAlignment() > AlignReq)
+        AlignReq = P.SJ->minAlignment();
       ++P.SJ;
     }
 
-    // Back up P.EndOffset if we ended the span early when encountering an
-    // unsplittable slice. This synthesizes the early end offset of
-    // a partition spanning only splittable slices.
-    if (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
-      assert(!P.SJ->isSplittable());
-      P.EndOffset = P.SJ->beginOffset();
+    // If we encountered an unsplittable slice we may have truncated the end of
+    // the partition to before its start and need to back up SJ.
+    while (P.SJ > P.SI && (P.SJ - 1)->beginOffset() >= P.EndOffset) {
+      assert(!Splittable && "Unwinding splittable partition!");
+      --P.SJ;
     }
+    assert(isAligned(AlignReq, P.BeginOffset) &&
+           "requestNext tried to create unaligned slice!");
+    R.commitNext();
   }
 
 public:

0 commit comments