From f922c7aee91e3207ff96822f316afc112657349a Mon Sep 17 00:00:00 2001 From: Jessica Clarke Date: Tue, 23 Sep 2025 23:53:21 +0100 Subject: [PATCH] [SROA] Handle more alignment corner cases rather than asserting The existing downstream alignment handling stopped merging in unsplittable slices if they would be unaligned. However, in this case, we're already covering at least part of such slices with the current partition, and that partition contains unsplittable slices, so we cannot just stop early. This had the result that, if we ever hit that case, the next call to advance would fail the assertion that BeginOffset didn't go backwards. The current single-pass implementation from upstream is complicated to reason about, mixing correctness properties (handling of unsplittable slices and, in our extended code, alignment) with heuristics on where to split within those constraints. It's also done as a single pass which, whilst workable without alignment constraints, is not in general possible once alignment is required (at least not unless you want to introduce the concept of head padding). This new implementation introduces AllocaSplitCandidates, which pre-scans the slices to determine a set of possible split points, and AllocaSplitRequester, which takes a heuristic-based request from AllocaSlices::partition_iterator and returns a legal split point for it to use. It also allows multiple queries before committing to one, so the split point can be advanced and walked back as more slices are scanned. The intent is that, despite the new implementation, the returned partitions are unchanged for all cases that weren't previously incorrectly-handled. In particular, since slices only ever have alignment other than 1 for CHERI, non-CHERI should not see any changes. --- llvm/lib/Transforms/Scalar/SROA.cpp | 277 +++++++++++++++--- .../SROA/cheri-crash-unaligned-overlap.ll | 106 +++++++ 2 files changed, 338 insertions(+), 45 deletions(-) create mode 100644 llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index b9b13ab4ac684..defbdd86cce13 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -24,6 +24,7 @@ #include "llvm/Transforms/Scalar/SROA.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/AddressRanges.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" @@ -857,6 +858,184 @@ class Partition { ArrayRef splitSliceTails() const { return SplitTails; } }; +class AllocaSplitCandidates { + friend class AllocaSplitRequester; + + AddressRanges CandidateSplits; + + // Some split endpoints are required as a simplification to ensure slices are + // kept aligned due to the current pre-computation implementation. + SmallVector RequiredSplits; + +public: + AllocaSplitCandidates(AllocaSlices::iterator SI, AllocaSlices::iterator SE) { + uint64_t Size = 0; + AddressRanges ExcludeRanges; + + // Compute the union of all unsplittable slices, and the max of all slices. + for (const auto &S : make_range(SI, SE)) { + Size = std::max(S.endOffset(), Size); + assert(S.endOffset() - S.beginOffset() > 0); + if (S.isSplittable() || S.endOffset() - S.beginOffset() < 2) + continue; + + ExcludeRanges.insert({S.beginOffset() + 1, S.endOffset()}); + LLVM_DEBUG(dbgs() << "Excluding split range [" << S.beginOffset() + 1 + << ", " << S.endOffset() - 1 + << "] due to unsplittable slice\n"); + } + + // Exclude ranges that would require introducing padding. That is, ensure + // that, for any candidate, it maintains the alignment of any unsplittable + // slices (splittable slices never have an alignment requirement) between + // it and the next candidate. + auto RSI = std::make_reverse_iterator(SE); + auto RSE = std::make_reverse_iterator(SI); + uint64_t Gap = ExcludeRanges.size(); + while (RSI != RSE) { + const auto *S = &*RSI++; + Align AlignReq = S->minAlignment(); + if (AlignReq == 1) + continue; + + // This slice has an alignment requirement. Keep walking back through the + // excluded ranges until we find a gap that can satisfy that alignment, + // along with the alignment for all other slices we encounter along the + // way. + uint64_t ExcludeEnd = S->beginOffset(); + uint64_t CurStart = alignDown(ExcludeEnd, AlignReq.value()); + bool FindGap = true; + while (FindGap) { + uint64_t GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end(); + uint64_t GapLast = + Gap == ExcludeRanges.size() ? Size : ExcludeRanges[Gap].start() - 1; + uint64_t AlignedGapBegin = alignTo(GapBegin, AlignReq); + uint64_t AlignedGapLast = alignDown(GapLast, AlignReq.value()); + while (Gap > 0) { + if (AlignedGapBegin <= AlignedGapLast && AlignedGapBegin <= CurStart) + break; + + --Gap; + GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end(); + GapLast = ExcludeRanges[Gap].start() - 1; + AlignedGapBegin = alignTo(GapBegin, AlignReq); + AlignedGapLast = alignDown(GapLast, AlignReq.value()); + } + + // This should always be true; either Gap > 0 and so we already checked + // this above, or Gap == 0 and so the offset is 0, which trivially + // satisfies this. + assert(AlignedGapBegin <= AlignedGapLast && + AlignedGapBegin <= CurStart && + "Could not find aligned split point!"); + CurStart = std::min(AlignedGapLast, CurStart); + FindGap = false; + + // Scan through all the slices that will be included between this split + // point and the previous one and check if any invalidate our choice of + // gap. + while (RSI != RSE && RSI->beginOffset() >= CurStart) { + S = &*RSI++; + if (S->minAlignment() <= AlignReq) + continue; + + AlignReq = S->minAlignment(); + AlignedGapBegin = alignTo(GapBegin, AlignReq); + AlignedGapLast = alignDown(GapLast, AlignReq.value()); + CurStart = alignDown(CurStart, AlignReq.value()); + if (AlignedGapBegin <= AlignedGapLast && AlignedGapBegin <= CurStart) + continue; + + FindGap = true; + } + } + + // Any split in (CurStart, ExcludeEnd] would not be satisfiable without + // padding. + if (CurStart < ExcludeEnd) { + ExcludeRanges.insert({CurStart + 1, ExcludeEnd + 1}); + LLVM_DEBUG(dbgs() << "Excluding split range [" << CurStart + 1 << ", " + << ExcludeEnd << "] due to aligned slices\n"); + } + + // We assume that we don't need to consider the alignment of any slices + // from this offset onwards on the next iteration (so that unaligned head + // slices can be split off regardless of what's in this tail), so enforce + // this by requiring a split at this point (which minimises the size of + // the partition that will include these slices). + if (CurStart > 0) + RequiredSplits.push_back(CurStart); + } + + for (size_t Gap = 0; Gap <= ExcludeRanges.size(); ++Gap) { + uint64_t GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end(); + uint64_t GapLast = + Gap == ExcludeRanges.size() ? Size : ExcludeRanges[Gap].start() - 1; + CandidateSplits.insert({GapBegin, GapLast + 1}); + LLVM_DEBUG(dbgs() << "Candidate split range [" << GapBegin << ", " + << GapLast << "]\n"); + } + + // Keep the required splits in ascending order + std::reverse(RequiredSplits.begin(), RequiredSplits.end()); + + LLVM_DEBUG({ + for (uint64_t Split : RequiredSplits) + dbgs() << "Required split " << Split << "\n"; + }); + } +}; + +class AllocaSplitRequester { + const AllocaSplitCandidates C; + + struct { + size_t CurCandidates = 0; + size_t NextRequired = 0; +#ifndef NDEBUG + uint64_t MinNextReq = 0; +#endif + } Committed, Staged; + +public: + AllocaSplitRequester(AllocaSlices::iterator SI, AllocaSlices::iterator SE) + : C(SI, SE) {} + + uint64_t requestNext(uint64_t Req, bool RoundUp) { + Staged = Committed; + + assert(Req < (C.CandidateSplits.end() - 1)->end()); + assert(Req >= Staged.MinNextReq); + + if (Staged.NextRequired < C.RequiredSplits.size()) + Req = std::min(Req, C.RequiredSplits[Staged.NextRequired]); + + while (Req >= C.CandidateSplits[Staged.CurCandidates].end()) + ++Staged.CurCandidates; + + if (Req < C.CandidateSplits[Staged.CurCandidates].start()) { + if (RoundUp) + Req = C.CandidateSplits[Staged.CurCandidates].start(); + else + Req = C.CandidateSplits[Staged.CurCandidates - 1].end() - 1; + } + + if (Staged.NextRequired < C.RequiredSplits.size()) { + if (Req == C.RequiredSplits[Staged.NextRequired]) + ++Staged.NextRequired; + else + assert(Req < C.RequiredSplits[Staged.NextRequired]); + } + +#ifndef NDEBUG + Staged.MinNextReq = Req + 1; +#endif + return Req; + } + + void commitNext() { Committed = Staged; } +}; + } // end anonymous namespace /// An iterator over partitions of the alloca's slices. @@ -880,6 +1059,10 @@ class AllocaSlices::partition_iterator /// We need to keep the end of the slices to know when to stop. AllocaSlices::iterator SE; + /// Calculates the legal split points for us to enact with whatever policy we + /// like. + AllocaSplitRequester R; + /// We also need to keep track of the maximum split end offset seen. /// FIXME: Do we really? uint64_t MaxSplitSliceEndOffset = 0; @@ -887,7 +1070,7 @@ class AllocaSlices::partition_iterator /// Sets the partition to be empty at given iterator, and sets the /// end iterator. partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE) - : P(SI), SE(SE) { + : P(SI), SE(SE), R(SI, SE) { // If not already at the end, advance our state to form the initial // partition. if (SI != SE) @@ -957,64 +1140,68 @@ class AllocaSlices::partition_iterator // If the we have split slices and the next slice is after a gap and is // not splittable immediately form an empty partition for the split - // slices up until the next slice begins. + // slices up until as close to the next slice as we can get, if possible. if (!P.SplitTails.empty() && P.SI->beginOffset() != P.EndOffset && !P.SI->isSplittable()) { - P.BeginOffset = P.EndOffset; - P.EndOffset = P.SI->beginOffset(); - return; + uint64_t NextOffset = + R.requestNext(P.SI->beginOffset(), /*RoundUp=*/false); + if (NextOffset > P.EndOffset) { + P.BeginOffset = P.EndOffset; + P.EndOffset = NextOffset; + R.commitNext(); + return; + } + assert(NextOffset == P.EndOffset && "requestNext went backwards!"); } } // OK, we need to consume new slices. Set the end offset based on the // current slice, and step SJ past it. The beginning offset of the - // partition is the beginning offset of the next slice unless we have - // pre-existing split slices that are continuing, in which case we begin - // at the prior end offset. - P.BeginOffset = P.SplitTails.empty() ? P.SI->beginOffset() : P.EndOffset; - P.EndOffset = P.SI->endOffset(); + // partition is as close to the beginning offset of the next slice as we + // can get, unless we have pre-existing split slices that are continuing, + // in which case we begin at the prior end offset. + if (P.SplitTails.empty() && P.SI->beginOffset() > P.EndOffset) { + P.BeginOffset = R.requestNext(P.SI->beginOffset(), /*RoundUp=*/false); + R.commitNext(); + } else + P.BeginOffset = P.EndOffset; + P.EndOffset = R.requestNext(P.SI->endOffset(), /*RoundUp=*/true); + bool Splittable = P.SI->isSplittable(); ++P.SJ; - // There are two strategies to form a partition based on whether the - // partition starts with an unsplittable slice or a splittable slice. - if (!P.SI->isSplittable()) { - // When we're forming an unsplittable region, it must always start at - // the first slice and will extend through its end. - assert(P.BeginOffset == P.SI->beginOffset()); - - // Form a partition including all of the overlapping slices with this - // unsplittable slice. - while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset && - P.SJ->isAligned(P.BeginOffset)) { - if (!P.SJ->isSplittable()) - P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset()); - ++P.SJ; + // Collect all the overlapping slices and grow the partition if possible. + // If we encounter an unsplittable slice, try to stop before it, otherwise + // stop as soon after it as possible. + Align AlignReq = P.SI->minAlignment(); + while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) { + assert((P.SJ->isSplittable() || P.SJ->endOffset() <= P.EndOffset) && + "requestNext tried to split unsplittable slice!"); + if (Splittable) { + if (!P.SJ->isSplittable()) { + Splittable = false; + if (P.SJ->beginOffset() > P.BeginOffset) { + P.EndOffset = R.requestNext(P.SJ->beginOffset(), /*RoundUp=*/false); + if (P.EndOffset > P.BeginOffset) + break; + } + } + if (!P.SJ->isSplittable() || P.SJ->endOffset() > P.EndOffset) + P.EndOffset = R.requestNext(P.SJ->endOffset(), /*RoundUp=*/true); } - - // We have a partition across a set of overlapping unsplittable - // partitions. - return; - } - - // If we're starting with a splittable slice, then we need to form - // a synthetic partition spanning it and any other overlapping splittable - // splices. - assert(P.SI->isSplittable() && "Forming a splittable partition!"); - - // Collect all of the overlapping splittable slices. - while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset && - P.SJ->isSplittable()) { - P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset()); + if (P.SJ->minAlignment() > AlignReq) + AlignReq = P.SJ->minAlignment(); ++P.SJ; } - // Back upiP.EndOffset if we ended the span early when encountering an - // unsplittable slice. This synthesizes the early end offset of - // a partition spanning only splittable slices. - if (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) { - assert(!P.SJ->isSplittable()); - P.EndOffset = P.SJ->beginOffset(); + // If we encountered an unsplittable slice we may have truncated the end of + // the partition to before its start and need to back up SJ. + while (P.SJ > P.SI && (P.SJ - 1)->beginOffset() >= P.EndOffset) { + assert(!Splittable && "Unwinding splittable partition!"); + --P.SJ; } + assert(isAligned(AlignReq, P.BeginOffset) && + "requestNext tried to create unaligned slice!"); + R.commitNext(); } public: diff --git a/llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll b/llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll new file mode 100644 index 0000000000000..90f97eac38d6d --- /dev/null +++ b/llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll @@ -0,0 +1,106 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2 +; RUN: opt -S -passes=sroa < %s | FileCheck %s +target datalayout = "e-m:e-pf200:128:128:128:64-p:64:64-i64:64-i128:128-n32:64-S128-A200-P200-G200" + +declare void @llvm.lifetime.start.p200(i64 immarg, ptr addrspace(200)) + +;; This previously crashed with: +;; +;; Assertion `BeginOffset < EndOffset && "Partitions must span some bytes!"' failed. +;; +;; This test is ultimately derived from specific usage within Chromium of: +;; +;; std::string("array of ") + (condition ? "mutable " : "") +;; +;; that was optimised to IR characterised by the partition access patterns in +;; this test and crashed Morello LLVM (note for example that the original +;; reduced IR loaded rather than stored a pointer, but the latter avoids issues +;; of poison, and the overlap has been reduced to just a single i16 at the +;; boundary). +define void @head_i16() { +; CHECK-LABEL: define void @head_i16() addrspace(200) { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[BUF:%.*]] = alloca [32 x i8], align 16, addrspace(200) +; CHECK-NEXT: call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) [[BUF]]) +; CHECK-NEXT: [[BUF_16_P16_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 16 +; CHECK-NEXT: store ptr addrspace(200) null, ptr addrspace(200) [[BUF_16_P16_SROA_IDX]], align 16 +; CHECK-NEXT: [[BUF_15_P15_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 15 +; CHECK-NEXT: store i16 0, ptr addrspace(200) [[BUF_15_P15_SROA_IDX]], align 1 +; CHECK-NEXT: ret void +; +entry: + %buf = alloca [32 x i8], align 16, addrspace(200) + call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf) + %p16 = getelementptr [32 x i8], ptr addrspace(200) %buf, i64 0, i64 16 + store ptr addrspace(200) null, ptr addrspace(200) %p16, align 16 + %p15 = getelementptr [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15 + store i16 0, ptr addrspace(200) %p15, align 1 + ret void +} + +;; Test we don't crash even when the overlap is due to an unaligned capability +;; (even if the align 16 is UB here). +define void @head_cap_misaligned() { +; CHECK-LABEL: define void @head_cap_misaligned() addrspace(200) { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[BUF:%.*]] = alloca [32 x i8], align 16, addrspace(200) +; CHECK-NEXT: call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) [[BUF]]) +; CHECK-NEXT: [[BUF_16_P16_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 16 +; CHECK-NEXT: store ptr addrspace(200) null, ptr addrspace(200) [[BUF_16_P16_SROA_IDX]], align 16 +; CHECK-NEXT: [[BUF_15_P15_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 15 +; CHECK-NEXT: store ptr addrspace(200) null, ptr addrspace(200) [[BUF_15_P15_SROA_IDX]], align 1 +; CHECK-NEXT: ret void +; +entry: + %buf = alloca [32 x i8], align 16, addrspace(200) + call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf) + %p16 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 16 + store ptr addrspace(200) null, ptr addrspace(200) %p16, align 16 + %p15 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15 + store ptr addrspace(200) null, ptr addrspace(200) %p15, align 16 + ret void +} + +;; Test we don't crash even when the overlap is due to an unaligned capability +;; that's marked as unaligned. +define void @head_cap_unaligned() { +; CHECK-LABEL: define void @head_cap_unaligned() addrspace(200) { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[BUF:%.*]] = alloca [32 x i8], align 16, addrspace(200) +; CHECK-NEXT: call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) [[BUF]]) +; CHECK-NEXT: [[BUF_16_P16_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 16 +; CHECK-NEXT: store ptr addrspace(200) null, ptr addrspace(200) [[BUF_16_P16_SROA_IDX]], align 16 +; CHECK-NEXT: [[BUF_15_P15_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 15 +; CHECK-NEXT: store ptr addrspace(200) null, ptr addrspace(200) [[BUF_15_P15_SROA_IDX]], align 1 +; CHECK-NEXT: ret void +; +entry: + %buf = alloca [32 x i8], align 16, addrspace(200) + call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf) + %p16 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 16 + store ptr addrspace(200) null, ptr addrspace(200) %p16, align 16 + %p15 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15 + store ptr addrspace(200) null, ptr addrspace(200) %p15, align 1 + ret void +} + +;; Test we don't crash when the overlap is past the end (and in fact omit the +;; tail padding as implicit). +define void @tail_i16() { +; CHECK-LABEL: define void @tail_i16() addrspace(200) { +; CHECK-NEXT: entry: +; CHECK-NEXT: [[BUF_SROA_0:%.*]] = alloca [17 x i8], align 16, addrspace(200) +; CHECK-NEXT: call void @llvm.lifetime.start.p200(i64 17, ptr addrspace(200) [[BUF_SROA_0]]) +; CHECK-NEXT: store ptr addrspace(200) null, ptr addrspace(200) [[BUF_SROA_0]], align 16 +; CHECK-NEXT: [[BUF_SROA_0_15_P15_SROA_IDX1:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF_SROA_0]], i64 15 +; CHECK-NEXT: store i16 0, ptr addrspace(200) [[BUF_SROA_0_15_P15_SROA_IDX1]], align 1 +; CHECK-NEXT: ret void +; +entry: + %buf = alloca [32 x i8], align 16, addrspace(200) + call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf) + store ptr addrspace(200) null, ptr addrspace(200) %buf, align 16 + %p15 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15 + store i16 0, ptr addrspace(200) %p15, align 1 + ret void +}