diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index b9b13ab4ac684..defbdd86cce13 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -24,6 +24,7 @@
 
 #include "llvm/Transforms/Scalar/SROA.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/AddressRanges.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -857,6 +858,184 @@ class Partition {
   ArrayRef<Slice> splitSliceTails() const { return SplitTails; }
 };
 
+class AllocaSplitCandidates {
+  friend class AllocaSplitRequester;
+
+  AddressRanges CandidateSplits;
+
+  // Some split endpoints are required as a simplification to ensure slices are
+  // kept aligned due to the current pre-computation implementation.
+  SmallVector<uint64_t> RequiredSplits;
+
+public:
+  AllocaSplitCandidates(AllocaSlices::iterator SI, AllocaSlices::iterator SE) {
+    uint64_t Size = 0;
+    AddressRanges ExcludeRanges;
+
+    // Compute the union of all unsplittable slices, and the max of all slices.
+    for (const auto &S : make_range(SI, SE)) {
+      Size = std::max(S.endOffset(), Size);
+      assert(S.endOffset() - S.beginOffset() > 0);
+      if (S.isSplittable() || S.endOffset() - S.beginOffset() < 2)
+        continue;
+
+      ExcludeRanges.insert({S.beginOffset() + 1, S.endOffset()});
+      LLVM_DEBUG(dbgs() << "Excluding split range [" << S.beginOffset() + 1
+                        << ", " << S.endOffset() - 1
+                        << "] due to unsplittable slice\n");
+    }
+
+    // Exclude ranges that would require introducing padding. That is, ensure
+    // that, for any candidate, it maintains the alignment of any unsplittable
+    // slices (splittable slices never have an alignment requirement) between
+    // it and the next candidate.
+    auto RSI = std::make_reverse_iterator(SE);
+    auto RSE = std::make_reverse_iterator(SI);
+    uint64_t Gap = ExcludeRanges.size();
+    while (RSI != RSE) {
+      const auto *S = &*RSI++;
+      Align AlignReq = S->minAlignment();
+      if (AlignReq == 1)
+        continue;
+
+      // This slice has an alignment requirement. Keep walking back through the
+      // excluded ranges until we find a gap that can satisfy that alignment,
+      // along with the alignment for all other slices we encounter along the
+      // way.
+      uint64_t ExcludeEnd = S->beginOffset();
+      uint64_t CurStart = alignDown(ExcludeEnd, AlignReq.value());
+      bool FindGap = true;
+      while (FindGap) {
+        uint64_t GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end();
+        uint64_t GapLast =
+            Gap == ExcludeRanges.size() ? Size : ExcludeRanges[Gap].start() - 1;
+        uint64_t AlignedGapBegin = alignTo(GapBegin, AlignReq);
+        uint64_t AlignedGapLast = alignDown(GapLast, AlignReq.value());
+        while (Gap > 0) {
+          if (AlignedGapBegin <= AlignedGapLast && AlignedGapBegin <= CurStart)
+            break;
+
+          --Gap;
+          GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end();
+          GapLast = ExcludeRanges[Gap].start() - 1;
+          AlignedGapBegin = alignTo(GapBegin, AlignReq);
+          AlignedGapLast = alignDown(GapLast, AlignReq.value());
+        }
+
+        // This should always be true; either Gap > 0 and so we already checked
+        // this above, or Gap == 0 and so the offset is 0, which trivially
+        // satisfies this.
+        assert(AlignedGapBegin <= AlignedGapLast &&
+               AlignedGapBegin <= CurStart &&
+               "Could not find aligned split point!");
+        CurStart = std::min(AlignedGapLast, CurStart);
+        FindGap = false;
+
+        // Scan through all the slices that will be included between this split
+        // point and the previous one and check if any invalidate our choice of
+        // gap.
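+        // For example, if one of those slices turns out to be a capability
+        // needing 16-byte alignment while the gap chosen so far is only
+        // 8-byte aligned, we raise AlignReq, re-align CurStart and, if the
+        // current gap can no longer satisfy the new requirement, restart the
+        // gap search.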
+        while (RSI != RSE && RSI->beginOffset() >= CurStart) {
+          S = &*RSI++;
+          if (S->minAlignment() <= AlignReq)
+            continue;
+
+          AlignReq = S->minAlignment();
+          AlignedGapBegin = alignTo(GapBegin, AlignReq);
+          AlignedGapLast = alignDown(GapLast, AlignReq.value());
+          CurStart = alignDown(CurStart, AlignReq.value());
+          if (AlignedGapBegin <= AlignedGapLast && AlignedGapBegin <= CurStart)
+            continue;
+
+          FindGap = true;
+        }
+      }
+
+      // Any split in (CurStart, ExcludeEnd] would not be satisfiable without
+      // padding.
+      if (CurStart < ExcludeEnd) {
+        ExcludeRanges.insert({CurStart + 1, ExcludeEnd + 1});
+        LLVM_DEBUG(dbgs() << "Excluding split range [" << CurStart + 1 << ", "
+                          << ExcludeEnd << "] due to aligned slices\n");
+      }
+
+      // We assume that we don't need to consider the alignment of any slices
+      // from this offset onwards on the next iteration (so that unaligned head
+      // slices can be split off regardless of what's in this tail), so enforce
+      // this by requiring a split at this point (which minimises the size of
+      // the partition that will include these slices).
+      if (CurStart > 0)
+        RequiredSplits.push_back(CurStart);
+    }
+
+    for (size_t Gap = 0; Gap <= ExcludeRanges.size(); ++Gap) {
+      uint64_t GapBegin = Gap == 0 ? 0 : ExcludeRanges[Gap - 1].end();
+      uint64_t GapLast =
+          Gap == ExcludeRanges.size() ? Size : ExcludeRanges[Gap].start() - 1;
+      CandidateSplits.insert({GapBegin, GapLast + 1});
+      LLVM_DEBUG(dbgs() << "Candidate split range [" << GapBegin << ", "
+                        << GapLast << "]\n");
+    }
+
+    // Keep the required splits in ascending order
+    std::reverse(RequiredSplits.begin(), RequiredSplits.end());
+
+    LLVM_DEBUG({
+      for (uint64_t Split : RequiredSplits)
+        dbgs() << "Required split " << Split << "\n";
+    });
+  }
+};
+
+class AllocaSplitRequester {
+  const AllocaSplitCandidates C;
+
+  struct {
+    size_t CurCandidates = 0;
+    size_t NextRequired = 0;
+#ifndef NDEBUG
+    uint64_t MinNextReq = 0;
+#endif
+  } Committed, Staged;
+
+public:
+  AllocaSplitRequester(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
+      : C(SI, SE) {}
+
+  uint64_t requestNext(uint64_t Req, bool RoundUp) {
+    Staged = Committed;
+
+    assert(Req < (C.CandidateSplits.end() - 1)->end());
+    assert(Req >= Staged.MinNextReq);
+
+    if (Staged.NextRequired < C.RequiredSplits.size())
+      Req = std::min(Req, C.RequiredSplits[Staged.NextRequired]);
+
+    while (Req >= C.CandidateSplits[Staged.CurCandidates].end())
+      ++Staged.CurCandidates;
+
+    if (Req < C.CandidateSplits[Staged.CurCandidates].start()) {
+      if (RoundUp)
+        Req = C.CandidateSplits[Staged.CurCandidates].start();
+      else
+        Req = C.CandidateSplits[Staged.CurCandidates - 1].end() - 1;
+    }
+
+    if (Staged.NextRequired < C.RequiredSplits.size()) {
+      if (Req == C.RequiredSplits[Staged.NextRequired])
+        ++Staged.NextRequired;
+      else
+        assert(Req < C.RequiredSplits[Staged.NextRequired]);
+    }
+
+#ifndef NDEBUG
+    Staged.MinNextReq = Req + 1;
+#endif
+    return Req;
+  }
+
+  void commitNext() { Committed = Staged; }
+};
+
 } // end anonymous namespace
 
 /// An iterator over partitions of the alloca's slices.
@@ -880,6 +1059,10 @@ class AllocaSlices::partition_iterator
   /// We need to keep the end of the slices to know when to stop.
   AllocaSlices::iterator SE;
 
+  /// Calculates the legal split points for us to enact with whatever policy we
+  /// like.
+  AllocaSplitRequester R;
+
   /// We also need to keep track of the maximum split end offset seen.
   /// FIXME: Do we really?
   uint64_t MaxSplitSliceEndOffset = 0;
@@ -887,7 +1070,7 @@ class AllocaSlices::partition_iterator
   /// Sets the partition to be empty at given iterator, and sets the
   /// end iterator.
   partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
-      : P(SI), SE(SE) {
+      : P(SI), SE(SE), R(SI, SE) {
     // If not already at the end, advance our state to form the initial
     // partition.
     if (SI != SE)
@@ -957,64 +1140,68 @@ class AllocaSlices::partition_iterator
 
       // If the we have split slices and the next slice is after a gap and is
      // not splittable immediately form an empty partition for the split
-      // slices up until the next slice begins.
+      // slices up until as close to the next slice as we can get, if possible.
       if (!P.SplitTails.empty() && P.SI->beginOffset() != P.EndOffset &&
           !P.SI->isSplittable()) {
-        P.BeginOffset = P.EndOffset;
-        P.EndOffset = P.SI->beginOffset();
-        return;
+        uint64_t NextOffset =
+            R.requestNext(P.SI->beginOffset(), /*RoundUp=*/false);
+        if (NextOffset > P.EndOffset) {
+          P.BeginOffset = P.EndOffset;
+          P.EndOffset = NextOffset;
+          R.commitNext();
+          return;
+        }
+        assert(NextOffset == P.EndOffset && "requestNext went backwards!");
       }
     }
 
     // OK, we need to consume new slices. Set the end offset based on the
     // current slice, and step SJ past it. The beginning offset of the
-    // partition is the beginning offset of the next slice unless we have
-    // pre-existing split slices that are continuing, in which case we begin
-    // at the prior end offset.
-    P.BeginOffset = P.SplitTails.empty() ? P.SI->beginOffset() : P.EndOffset;
-    P.EndOffset = P.SI->endOffset();
+    // partition is as close to the beginning offset of the next slice as we
+    // can get, unless we have pre-existing split slices that are continuing,
+    // in which case we begin at the prior end offset.
+    if (P.SplitTails.empty() && P.SI->beginOffset() > P.EndOffset) {
+      P.BeginOffset = R.requestNext(P.SI->beginOffset(), /*RoundUp=*/false);
+      R.commitNext();
+    } else
+      P.BeginOffset = P.EndOffset;
+    P.EndOffset = R.requestNext(P.SI->endOffset(), /*RoundUp=*/true);
+    bool Splittable = P.SI->isSplittable();
     ++P.SJ;
 
-    // There are two strategies to form a partition based on whether the
-    // partition starts with an unsplittable slice or a splittable slice.
-    if (!P.SI->isSplittable()) {
-      // When we're forming an unsplittable region, it must always start at
-      // the first slice and will extend through its end.
-      assert(P.BeginOffset == P.SI->beginOffset());
-
-      // Form a partition including all of the overlapping slices with this
-      // unsplittable slice.
-      while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
-             P.SJ->isAligned(P.BeginOffset)) {
-        if (!P.SJ->isSplittable())
-          P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
-        ++P.SJ;
+    // Collect all the overlapping slices and grow the partition if possible.
+    // If we encounter an unsplittable slice, try to stop before it, otherwise
+    // stop as soon after it as possible.
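+    // Note that requestNext only ever returns offsets drawn from the
+    // precomputed candidate ranges, so an end offset it hands back can neither
+    // bisect an unsplittable slice nor leave an aligned slice at an offset its
+    // alignment cannot tolerate; the asserts below check that invariant.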
+    Align AlignReq = P.SI->minAlignment();
+    while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
+      assert((P.SJ->isSplittable() || P.SJ->endOffset() <= P.EndOffset) &&
+             "requestNext tried to split unsplittable slice!");
+      if (Splittable) {
+        if (!P.SJ->isSplittable()) {
+          Splittable = false;
+          if (P.SJ->beginOffset() > P.BeginOffset) {
+            P.EndOffset = R.requestNext(P.SJ->beginOffset(), /*RoundUp=*/false);
+            if (P.EndOffset > P.BeginOffset)
+              break;
+          }
+        }
+        if (!P.SJ->isSplittable() || P.SJ->endOffset() > P.EndOffset)
+          P.EndOffset = R.requestNext(P.SJ->endOffset(), /*RoundUp=*/true);
       }
-
-      // We have a partition across a set of overlapping unsplittable
-      // partitions.
-      return;
-    }
-
-    // If we're starting with a splittable slice, then we need to form
-    // a synthetic partition spanning it and any other overlapping splittable
-    // splices.
-    assert(P.SI->isSplittable() && "Forming a splittable partition!");
-
-    // Collect all of the overlapping splittable slices.
-    while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
-           P.SJ->isSplittable()) {
-      P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
+      if (P.SJ->minAlignment() > AlignReq)
+        AlignReq = P.SJ->minAlignment();
       ++P.SJ;
     }
 
-    // Back upiP.EndOffset if we ended the span early when encountering an
-    // unsplittable slice. This synthesizes the early end offset of
-    // a partition spanning only splittable slices.
-    if (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
-      assert(!P.SJ->isSplittable());
-      P.EndOffset = P.SJ->beginOffset();
+    // If we encountered an unsplittable slice we may have truncated the end of
+    // the partition to before its start and need to back up SJ.
+    while (P.SJ > P.SI && (P.SJ - 1)->beginOffset() >= P.EndOffset) {
+      assert(!Splittable && "Unwinding splittable partition!");
+      --P.SJ;
     }
+    assert(isAligned(AlignReq, P.BeginOffset) &&
+           "requestNext tried to create unaligned slice!");
+    R.commitNext();
   }
 
 public:
diff --git a/llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll b/llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll
new file mode 100644
index 0000000000000..90f97eac38d6d
--- /dev/null
+++ b/llvm/test/Transforms/SROA/cheri-crash-unaligned-overlap.ll
@@ -0,0 +1,106 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt -S -passes=sroa < %s | FileCheck %s
+target datalayout = "e-m:e-pf200:128:128:128:64-p:64:64-i64:64-i128:128-n32:64-S128-A200-P200-G200"
+
+declare void @llvm.lifetime.start.p200(i64 immarg, ptr addrspace(200))
+
+;; This previously crashed with:
+;;
+;; Assertion `BeginOffset < EndOffset && "Partitions must span some bytes!"' failed.
+;;
+;; This test is ultimately derived from specific usage within Chromium of:
+;;
+;; std::string("array of ") + (condition ? "mutable " : "")
+;;
+;; that was optimised to IR characterised by the partition access patterns in
+;; this test and crashed Morello LLVM (note for example that the original
+;; reduced IR loaded rather than stored a pointer, but the latter avoids issues
+;; of poison, and the overlap has been reduced to just a single i16 at the
+;; boundary).
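+;;
+;; In @head_i16 below, the capability store makes [16, 32) an unsplittable
+;; slice requiring 16-byte alignment and the i16 store adds an unsplittable
+;; slice at [15, 17), so every interior offset is rejected as a split point
+;; (16 would split the i16, 17-31 would split the capability, and 1-15 would
+;; leave the capability misaligned); the alloca is therefore kept as a single
+;; partition rather than crashing on an empty one.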
+define void @head_i16() {
+; CHECK-LABEL: define void @head_i16() addrspace(200) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[BUF:%.*]] = alloca [32 x i8], align 16, addrspace(200)
+; CHECK-NEXT:    call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) [[BUF]])
+; CHECK-NEXT:    [[BUF_16_P16_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 16
+; CHECK-NEXT:    store ptr addrspace(200) null, ptr addrspace(200) [[BUF_16_P16_SROA_IDX]], align 16
+; CHECK-NEXT:    [[BUF_15_P15_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 15
+; CHECK-NEXT:    store i16 0, ptr addrspace(200) [[BUF_15_P15_SROA_IDX]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  %buf = alloca [32 x i8], align 16, addrspace(200)
+  call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf)
+  %p16 = getelementptr [32 x i8], ptr addrspace(200) %buf, i64 0, i64 16
+  store ptr addrspace(200) null, ptr addrspace(200) %p16, align 16
+  %p15 = getelementptr [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15
+  store i16 0, ptr addrspace(200) %p15, align 1
+  ret void
+}
+
+;; Test we don't crash even when the overlap is due to an unaligned capability
+;; (even if the align 16 is UB here).
+define void @head_cap_misaligned() {
+; CHECK-LABEL: define void @head_cap_misaligned() addrspace(200) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[BUF:%.*]] = alloca [32 x i8], align 16, addrspace(200)
+; CHECK-NEXT:    call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) [[BUF]])
+; CHECK-NEXT:    [[BUF_16_P16_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 16
+; CHECK-NEXT:    store ptr addrspace(200) null, ptr addrspace(200) [[BUF_16_P16_SROA_IDX]], align 16
+; CHECK-NEXT:    [[BUF_15_P15_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 15
+; CHECK-NEXT:    store ptr addrspace(200) null, ptr addrspace(200) [[BUF_15_P15_SROA_IDX]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  %buf = alloca [32 x i8], align 16, addrspace(200)
+  call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf)
+  %p16 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 16
+  store ptr addrspace(200) null, ptr addrspace(200) %p16, align 16
+  %p15 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15
+  store ptr addrspace(200) null, ptr addrspace(200) %p15, align 16
+  ret void
+}
+
+;; Test we don't crash even when the overlap is due to an unaligned capability
+;; that's marked as unaligned.
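+;; Unlike @head_cap_misaligned, the overlapping store here carries align 1, so
+;; the slice's declared minimum alignment is consistent with its offset of 15;
+;; the expected output is the same as for @head_cap_misaligned.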
+define void @head_cap_unaligned() {
+; CHECK-LABEL: define void @head_cap_unaligned() addrspace(200) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[BUF:%.*]] = alloca [32 x i8], align 16, addrspace(200)
+; CHECK-NEXT:    call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) [[BUF]])
+; CHECK-NEXT:    [[BUF_16_P16_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 16
+; CHECK-NEXT:    store ptr addrspace(200) null, ptr addrspace(200) [[BUF_16_P16_SROA_IDX]], align 16
+; CHECK-NEXT:    [[BUF_15_P15_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF]], i64 15
+; CHECK-NEXT:    store ptr addrspace(200) null, ptr addrspace(200) [[BUF_15_P15_SROA_IDX]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  %buf = alloca [32 x i8], align 16, addrspace(200)
+  call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf)
+  %p16 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 16
+  store ptr addrspace(200) null, ptr addrspace(200) %p16, align 16
+  %p15 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15
+  store ptr addrspace(200) null, ptr addrspace(200) %p15, align 1
+  ret void
+}
+
+;; Test we don't crash when the overlap is past the end (and in fact omit the
+;; tail padding as implicit).
+define void @tail_i16() {
+; CHECK-LABEL: define void @tail_i16() addrspace(200) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[BUF_SROA_0:%.*]] = alloca [17 x i8], align 16, addrspace(200)
+; CHECK-NEXT:    call void @llvm.lifetime.start.p200(i64 17, ptr addrspace(200) [[BUF_SROA_0]])
+; CHECK-NEXT:    store ptr addrspace(200) null, ptr addrspace(200) [[BUF_SROA_0]], align 16
+; CHECK-NEXT:    [[BUF_SROA_0_15_P15_SROA_IDX1:%.*]] = getelementptr inbounds i8, ptr addrspace(200) [[BUF_SROA_0]], i64 15
+; CHECK-NEXT:    store i16 0, ptr addrspace(200) [[BUF_SROA_0_15_P15_SROA_IDX1]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  %buf = alloca [32 x i8], align 16, addrspace(200)
+  call void @llvm.lifetime.start.p200(i64 32, ptr addrspace(200) %buf)
+  store ptr addrspace(200) null, ptr addrspace(200) %buf, align 16
+  %p15 = getelementptr inbounds [32 x i8], ptr addrspace(200) %buf, i64 0, i64 15
+  store i16 0, ptr addrspace(200) %p15, align 1
+  ret void
+}