[InferAlignment] Propagate alignment between loads/stores of the same base pointer #145733

Merged
41 changes: 41 additions & 0 deletions llvm/lib/Transforms/Scalar/InferAlignment.cpp
@@ -21,6 +21,43 @@

using namespace llvm;

static bool tryToPropagateAlign(Function &F, const DataLayout &DL) {
  bool Changed = false;

  for (BasicBlock &BB : F) {
    // Reset the map for each block: alignment information cannot be
    // propagated across blocks, because control flow may depend on the
    // address at runtime, so an alignment assumption that holds in one
    // block need not hold in another. A dominator-tree-based approach
    // could do better, but restricting propagation to a single basic
    // block is always correct.
    DenseMap<Value *, Align> BestBasePointerAligns;
    for (Instruction &I : BB) {
      if (auto *PtrOp = getLoadStorePointerOperand(&I)) {
        Align LoadStoreAlign = getLoadStoreAlignment(&I);
        APInt OffsetFromBase = APInt(
            DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()),
            0);
        PtrOp = PtrOp->stripAndAccumulateInBoundsConstantOffsets(
            DL, OffsetFromBase);
        // The alignment this access implies for the underlying base pointer.
        Align BasePointerAlign =
            commonAlignment(LoadStoreAlign, OffsetFromBase.getLimitedValue());

        if (BestBasePointerAligns.count(PtrOp) &&
            BestBasePointerAligns[PtrOp] > BasePointerAlign) {
          // A previous access proved a stronger base alignment; use it to
          // upgrade this access.
          Align BetterLoadStoreAlign = commonAlignment(
              BestBasePointerAligns[PtrOp], OffsetFromBase.getLimitedValue());
          setLoadStoreAlignment(&I, BetterLoadStoreAlign);
          Changed = true;
        } else {
          BestBasePointerAligns[PtrOp] = BasePointerAlign;
        }
      }
    }
  }
  return Changed;
}
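
For intuition, here is a minimal standalone sketch of the alignment arithmetic used above (plain C++; commonAlign below is an illustrative stand-in for llvm::commonAlignment, not the LLVM implementation). An access with alignment A at constant offset Off from a base proves the base is aligned to the largest power of two dividing both, and the same arithmetic runs in reverse once a stronger base alignment is known:

#include <cassert>
#include <cstdint>

// Stand-in for llvm::commonAlignment: the largest power of two that
// divides both A and Offset, i.e. the lowest set bit of (A | Offset).
static uint64_t commonAlign(uint64_t A, uint64_t Offset) {
  uint64_t V = A | Offset;
  return V & -V;
}

int main() {
  // An align-16 access at offset 0 proves the base is 16-aligned.
  assert(commonAlign(16, 0) == 16);
  // A 16-aligned base upgrades an access at offset 8 to align 8...
  assert(commonAlign(16, 8) == 8);
  // ...but an access at offset 4 only to align 4 (16 does not divide 4).
  assert(commonAlign(16, 4) == 4);
  return 0;
}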

static bool tryToImproveAlign(
    const DataLayout &DL, Instruction *I,
    function_ref<Align(Value *PtrOp, Align OldAlign, Align PrefAlign)> Fn) {
@@ -70,6 +107,10 @@ bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
    }
  }

  // Propagate alignment between loads and stores that originate from the
  // same base pointer.
  Changed |= tryToPropagateAlign(F, DL);

  return Changed;
}

41 changes: 33 additions & 8 deletions llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -343,6 +343,9 @@ class Vectorizer {
  /// Postcondition: For all i, ret[i][0].second == 0, because the first instr
  /// in the chain is the leader, and an instr touches distance 0 from itself.
  std::vector<Chain> gatherChains(ArrayRef<Instruction *> Instrs);

  /// Propagates the best alignment in a chain of contiguous accesses.
  void propagateBestAlignmentInChain(ArrayRef<ChainElem> C) const;
};

class LoadStoreVectorizerLegacyPass : public FunctionPass {
@@ -716,6 +719,14 @@ std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
  unsigned AS = getLoadStoreAddressSpace(C[0].Inst);
  unsigned VecRegBytes = TTI.getLoadStoreVecRegBitWidth(AS) / 8;

  // We know that the accesses are contiguous. Propagate alignment
  // information so that slices of the chain can still be vectorized.
  propagateBestAlignmentInChain(C);
  LLVM_DEBUG({
    dbgs() << "LSV: Chain after alignment propagation:\n";
    dumpChain(C);
  });

  std::vector<Chain> Ret;
  for (unsigned CBegin = 0; CBegin < C.size(); ++CBegin) {
    // Find candidate chains of size not greater than the largest vector reg.
@@ -823,6 +834,7 @@ std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
                          << Alignment.value() << " to " << NewAlign.value()
                          << "\n");
        Alignment = NewAlign;
        setLoadStoreAlignment(C[CBegin].Inst, Alignment);
      }
    }

@@ -880,14 +892,6 @@ bool Vectorizer::vectorizeChain(Chain &C) {
      VecElemTy, 8 * ChainBytes / DL.getTypeSizeInBits(VecElemTy));

  Align Alignment = getLoadStoreAlignment(C[0].Inst);
  // If this is a load/store of an alloca, we might have upgraded the alloca's
  // alignment earlier. Get the new alignment.
  if (AS == DL.getAllocaAddrSpace()) {
    Alignment = std::max(
        Alignment,
        getOrEnforceKnownAlignment(getLoadStorePointerOperand(C[0].Inst),
                                   MaybeAlign(), DL, C[0].Inst, nullptr, &DT));
  }

// All elements of the chain must have the same scalar-type size.
#ifndef NDEBUG
@@ -1634,3 +1638,24 @@ std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB,
        .sextOrTrunc(OrigBitWidth);
  return std::nullopt;
}

void Vectorizer::propagateBestAlignmentInChain(ArrayRef<ChainElem> C) const {
  // Find the element in the chain with the best alignment and its offset.
  Align BestAlign = getLoadStoreAlignment(C[0].Inst);
  APInt BestAlignOffset = C[0].OffsetFromLeader;
  for (const ChainElem &Elem : C) {
    Align ElemAlign = getLoadStoreAlignment(Elem.Inst);
    if (ElemAlign > BestAlign) {
      BestAlign = ElemAlign;
      BestAlignOffset = Elem.OffsetFromLeader;
    }
  }

  // Propagate the best alignment to other elements in the chain, if possible.
  for (const ChainElem &Elem : C) {
    APInt OffsetDelta = APIntOps::abdu(Elem.OffsetFromLeader, BestAlignOffset);
    Align NewAlign = commonAlignment(BestAlign, OffsetDelta.getLimitedValue());
    if (NewAlign > getLoadStoreAlignment(Elem.Inst))
      setLoadStoreAlignment(Elem.Inst, NewAlign);
  }
}
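
To see the propagation end to end, here is a small self-contained sketch over a hypothetical chain (plain integers stand in for llvm::Align and APInt, and the offsets and alignments are made up for illustration; the real pass operates on ChainElem as above):

#include <cassert>
#include <cstdint>

// Stand-in for llvm::commonAlignment, as in the earlier sketch.
static uint64_t commonAlign(uint64_t A, uint64_t B) {
  uint64_t V = A | B;
  return V & -V;
}

int main() {
  // Hypothetical chain: accesses at offsets 0, 4, 8 from the leader;
  // the access at offset 8 is known to be 16-aligned.
  struct Elem { uint64_t Offset, Align; };
  Elem Chain[] = {{0, 4}, {4, 4}, {8, 16}};
  const uint64_t BestAlign = 16, BestOffset = 8;

  for (Elem &E : Chain) {
    // Absolute offset distance from the best-aligned element
    // (what APIntOps::abdu computes above).
    uint64_t Delta = E.Offset > BestOffset ? E.Offset - BestOffset
                                           : BestOffset - E.Offset;
    uint64_t NewAlign = commonAlign(BestAlign, Delta);
    if (NewAlign > E.Align) // only ever upgrade, never weaken
      E.Align = NewAlign;
  }
  assert(Chain[0].Align == 8);  // |0 - 8| = 8, commonAlign(16, 8) = 8
  assert(Chain[1].Align == 4);  // |4 - 8| = 4, commonAlign(16, 4) = 4
  assert(Chain[2].Align == 16); // delta 0 keeps the best alignment
  return 0;
}
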
@@ -0,0 +1,136 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt < %s -passes=infer-alignment -S | FileCheck %s
%struct.S1 = type { %struct.float3, %struct.float3, i32, i32 }
%struct.float3 = type { float, float, float }


; ------------------------------------------------------------------------------
; Test that we can propagate align 16 to the loads and stores that are set to align 4
; ------------------------------------------------------------------------------

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
define void @prop_align(ptr noundef readonly captures(none) %v, ptr noundef writeonly captures(none) initializes((0, 32)) %vout) local_unnamed_addr #0 {
; CHECK-LABEL: define void @prop_align(
; CHECK-SAME: ptr noundef readonly captures(none) [[V:%.*]], ptr noundef writeonly captures(none) initializes((0, 32)) [[VOUT:%.*]]) local_unnamed_addr {
; CHECK-NEXT: [[DOTUNPACK_UNPACK:%.*]] = load float, ptr [[V]], align 16
; CHECK-NEXT: [[DOTUNPACK_ELT7:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 4
; CHECK-NEXT: [[DOTUNPACK_UNPACK8:%.*]] = load float, ptr [[DOTUNPACK_ELT7]], align 4
; CHECK-NEXT: [[DOTUNPACK_ELT9:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 8
; CHECK-NEXT: [[DOTUNPACK_UNPACK10:%.*]] = load float, ptr [[DOTUNPACK_ELT9]], align 8
; CHECK-NEXT: [[DOTELT1:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 12
; CHECK-NEXT: [[DOTUNPACK2_UNPACK:%.*]] = load float, ptr [[DOTELT1]], align 4
; CHECK-NEXT: [[DOTUNPACK2_ELT12:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 16
; CHECK-NEXT: [[DOTUNPACK2_UNPACK13:%.*]] = load float, ptr [[DOTUNPACK2_ELT12]], align 16
; CHECK-NEXT: [[DOTUNPACK2_ELT14:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 20
; CHECK-NEXT: [[DOTUNPACK2_UNPACK15:%.*]] = load float, ptr [[DOTUNPACK2_ELT14]], align 4
; CHECK-NEXT: [[DOTELT3:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 24
; CHECK-NEXT: [[DOTUNPACK4:%.*]] = load i32, ptr [[DOTELT3]], align 8
; CHECK-NEXT: [[DOTELT5:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 28
; CHECK-NEXT: [[DOTUNPACK6:%.*]] = load i32, ptr [[DOTELT5]], align 4
; CHECK-NEXT: store float [[DOTUNPACK_UNPACK]], ptr [[VOUT]], align 16
; CHECK-NEXT: [[VOUT_REPACK23:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 4
; CHECK-NEXT: store float [[DOTUNPACK_UNPACK8]], ptr [[VOUT_REPACK23]], align 4
; CHECK-NEXT: [[VOUT_REPACK25:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 8
; CHECK-NEXT: store float [[DOTUNPACK_UNPACK10]], ptr [[VOUT_REPACK25]], align 8
; CHECK-NEXT: [[VOUT_REPACK17:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 12
; CHECK-NEXT: store float [[DOTUNPACK2_UNPACK]], ptr [[VOUT_REPACK17]], align 4
; CHECK-NEXT: [[VOUT_REPACK17_REPACK27:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 16
; CHECK-NEXT: store float [[DOTUNPACK2_UNPACK13]], ptr [[VOUT_REPACK17_REPACK27]], align 16
; CHECK-NEXT: [[VOUT_REPACK17_REPACK29:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 20
; CHECK-NEXT: store float [[DOTUNPACK2_UNPACK15]], ptr [[VOUT_REPACK17_REPACK29]], align 4
; CHECK-NEXT: [[VOUT_REPACK19:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 24
; CHECK-NEXT: store i32 [[DOTUNPACK4]], ptr [[VOUT_REPACK19]], align 8
; CHECK-NEXT: [[VOUT_REPACK21:%.*]] = getelementptr inbounds nuw i8, ptr [[VOUT]], i64 28
; CHECK-NEXT: store i32 [[DOTUNPACK6]], ptr [[VOUT_REPACK21]], align 4
; CHECK-NEXT: ret void
;
%.unpack.unpack = load float, ptr %v, align 16
%.unpack.elt7 = getelementptr inbounds nuw i8, ptr %v, i64 4
%.unpack.unpack8 = load float, ptr %.unpack.elt7, align 4
%.unpack.elt9 = getelementptr inbounds nuw i8, ptr %v, i64 8
%.unpack.unpack10 = load float, ptr %.unpack.elt9, align 8
%.elt1 = getelementptr inbounds nuw i8, ptr %v, i64 12
%.unpack2.unpack = load float, ptr %.elt1, align 4
%.unpack2.elt12 = getelementptr inbounds nuw i8, ptr %v, i64 16
%.unpack2.unpack13 = load float, ptr %.unpack2.elt12, align 4
%.unpack2.elt14 = getelementptr inbounds nuw i8, ptr %v, i64 20
%.unpack2.unpack15 = load float, ptr %.unpack2.elt14, align 4
%.elt3 = getelementptr inbounds nuw i8, ptr %v, i64 24
%.unpack4 = load i32, ptr %.elt3, align 8
%.elt5 = getelementptr inbounds nuw i8, ptr %v, i64 28
%.unpack6 = load i32, ptr %.elt5, align 4
store float %.unpack.unpack, ptr %vout, align 16
%vout.repack23 = getelementptr inbounds nuw i8, ptr %vout, i64 4
store float %.unpack.unpack8, ptr %vout.repack23, align 4
%vout.repack25 = getelementptr inbounds nuw i8, ptr %vout, i64 8
store float %.unpack.unpack10, ptr %vout.repack25, align 8
%vout.repack17 = getelementptr inbounds nuw i8, ptr %vout, i64 12
store float %.unpack2.unpack, ptr %vout.repack17, align 4
%vout.repack17.repack27 = getelementptr inbounds nuw i8, ptr %vout, i64 16
store float %.unpack2.unpack13, ptr %vout.repack17.repack27, align 4
%vout.repack17.repack29 = getelementptr inbounds nuw i8, ptr %vout, i64 20
store float %.unpack2.unpack15, ptr %vout.repack17.repack29, align 4
%vout.repack19 = getelementptr inbounds nuw i8, ptr %vout, i64 24
store i32 %.unpack4, ptr %vout.repack19, align 8
%vout.repack21 = getelementptr inbounds nuw i8, ptr %vout, i64 28
store i32 %.unpack6, ptr %vout.repack21, align 4
ret void
}

; ------------------------------------------------------------------------------
; Test that alignment is not propagated from a source that does not dominate the destination
; ------------------------------------------------------------------------------

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
define void @no_prop_align(ptr noundef readonly captures(none) %v, ptr noundef writeonly captures(none) initializes((0, 32)) %vout, i1 %cond) local_unnamed_addr #0 {
; CHECK-LABEL: define void @no_prop_align(
; CHECK-SAME: ptr noundef readonly captures(none) [[V:%.*]], ptr noundef writeonly captures(none) initializes((0, 32)) [[VOUT:%.*]], i1 [[COND:%.*]]) local_unnamed_addr {
; CHECK-NEXT: br i1 [[COND]], label %[[BRANCH1:.*]], label %[[BRANCH2:.*]]
; CHECK: [[BRANCH1]]:
; CHECK-NEXT: [[DOTUNPACK_UNPACK:%.*]] = load float, ptr [[V]], align 16
; CHECK-NEXT: [[DOTUNPACK_ELT7:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 4
; CHECK-NEXT: [[DOTUNPACK_UNPACK8:%.*]] = load float, ptr [[DOTUNPACK_ELT7]], align 4
; CHECK-NEXT: [[DOTUNPACK_ELT9:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 8
; CHECK-NEXT: [[DOTUNPACK_UNPACK10:%.*]] = load float, ptr [[DOTUNPACK_ELT9]], align 8
; CHECK-NEXT: [[DOTELT1:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 12
; CHECK-NEXT: [[DOTUNPACK2_UNPACK:%.*]] = load float, ptr [[DOTELT1]], align 4
; CHECK-NEXT: br label %[[END:.*]]
; CHECK: [[BRANCH2]]:
; CHECK-NEXT: [[DOTUNPACK2_ELT12:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 16
; CHECK-NEXT: [[DOTUNPACK2_UNPACK13:%.*]] = load float, ptr [[DOTUNPACK2_ELT12]], align 4
; CHECK-NEXT: [[DOTUNPACK2_ELT14:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 20
; CHECK-NEXT: [[DOTUNPACK2_UNPACK15:%.*]] = load float, ptr [[DOTUNPACK2_ELT14]], align 4
; CHECK-NEXT: [[DOTELT3:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 24
; CHECK-NEXT: [[DOTUNPACK4:%.*]] = load i32, ptr [[DOTELT3]], align 8
; CHECK-NEXT: [[DOTELT5:%.*]] = getelementptr inbounds nuw i8, ptr [[V]], i64 28
; CHECK-NEXT: [[DOTUNPACK6:%.*]] = load i32, ptr [[DOTELT5]], align 4
; CHECK-NEXT: br label %[[END]]
; CHECK: [[END]]:
; CHECK-NEXT: ret void
;
br i1 %cond, label %branch1, label %branch2

branch1:
%.unpack.unpack = load float, ptr %v, align 16
%.unpack.elt7 = getelementptr inbounds nuw i8, ptr %v, i64 4
%.unpack.unpack8 = load float, ptr %.unpack.elt7, align 4
%.unpack.elt9 = getelementptr inbounds nuw i8, ptr %v, i64 8
%.unpack.unpack10 = load float, ptr %.unpack.elt9, align 8
%.elt1 = getelementptr inbounds nuw i8, ptr %v, i64 12
%.unpack2.unpack = load float, ptr %.elt1, align 4
br label %end

branch2:
%.unpack2.elt12 = getelementptr inbounds nuw i8, ptr %v, i64 16
%.unpack2.unpack13 = load float, ptr %.unpack2.elt12, align 4
%.unpack2.elt14 = getelementptr inbounds nuw i8, ptr %v, i64 20
%.unpack2.unpack15 = load float, ptr %.unpack2.elt14, align 4
%.elt3 = getelementptr inbounds nuw i8, ptr %v, i64 24
%.unpack4 = load i32, ptr %.elt3, align 8
%.elt5 = getelementptr inbounds nuw i8, ptr %v, i64 28
%.unpack6 = load i32, ptr %.elt5, align 4
br label %end

end:
ret void
}
@@ -155,7 +155,7 @@ define void @variadics1(ptr %vlist) {
; CHECK-NEXT: [[ARGP_NEXT12:%.*]] = getelementptr i8, ptr [[ARGP_CUR11_ALIGNED]], i64 8
; CHECK-NEXT: [[X2:%.*]] = getelementptr i8, ptr [[ARGP_NEXT12]], i32 7
; CHECK-NEXT: [[ARGP_CUR16_ALIGNED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[X2]], i64 0)
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, ptr [[ARGP_CUR16_ALIGNED]], align 4294967296
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x double>, ptr [[ARGP_CUR16_ALIGNED]], align 8
; CHECK-NEXT: [[X31:%.*]] = extractelement <2 x double> [[TMP1]], i32 0
; CHECK-NEXT: [[X42:%.*]] = extractelement <2 x double> [[TMP1]], i32 1
; CHECK-NEXT: [[X5:%.*]] = fadd double [[X42]], [[X31]]