Skip to content

Commit fb7ebba

Browse files
committed
[LV] Mask off possibly aliasing vector lanes
When vectorising a loop that uses loads and stores, those pointers could overlap if their difference is less than the vector factor. For example, if address 20 is being stored to and address 23 is being loaded from, they overlap when the vector factor is 4 or higher. Currently LoopVectorize branches to a scalar loop in these cases with a runtime check. However, if we construct a mask that disables the overlapping (aliasing) lanes then the vectorised loop can be safely entered, as long as the loads and stores are masked off.
1 parent eab11c8 commit fb7ebba

19 files changed

+794
-26
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,12 @@ struct PointerDiffInfo {
491491
const SCEV *SinkStart;
492492
unsigned AccessSize;
493493
bool NeedsFreeze;
494+
bool WriteAfterRead;
494495

495496
PointerDiffInfo(const SCEV *SrcStart, const SCEV *SinkStart,
496-
unsigned AccessSize, bool NeedsFreeze)
497+
unsigned AccessSize, bool NeedsFreeze, bool WriteAfterRead)
497498
: SrcStart(SrcStart), SinkStart(SinkStart), AccessSize(AccessSize),
498-
NeedsFreeze(NeedsFreeze) {}
499+
NeedsFreeze(NeedsFreeze), WriteAfterRead(WriteAfterRead) {}
499500
};
500501

501502
/// Holds information about the memory runtime legality checks to verify

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,13 @@ enum class TailFoldingStyle {
205205
DataWithEVL,
206206
};
207207

208+
enum class RTCheckStyle {
209+
/// Create runtime checks based on the difference between two pointers
210+
ScalarDifference,
211+
/// Form a mask based on elements which won't be a WAR or RAW hazard.
212+
UseSafeEltsMask,
213+
};
214+
208215
struct TailFoldingInfo {
209216
TargetLibraryInfo *TLI;
210217
LoopVectorizationLegality *LVL;
@@ -1357,6 +1364,11 @@ class TargetTransformInfo {
13571364
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
13581365
TTI::TargetCostKind CostKind) const;
13591366

1367+
/// \return true if a mask should be formed that disables lanes that could
1368+
/// alias between two pointers. The mask is created by the
1369+
/// loop_dependence_{war,raw}_mask intrinsics.
1370+
LLVM_ABI bool useSafeEltsMask() const;
1371+
13601372
/// \return The maximum interleave factor that any transform should try to
13611373
/// perform for this target. This number depends on the level of parallelism
13621374
/// and the number of execution units in the CPU.

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,8 @@ class TargetTransformInfoImplBase {
659659
return InstructionCost::getInvalid();
660660
}
661661

662+
virtual bool useSafeEltsMask() const { return false; }
663+
662664
virtual unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
663665

664666
virtual InstructionCost getArithmeticInstrCost(

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,11 @@ addRuntimeChecks(Instruction *Loc, Loop *TheLoop,
614614
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
615615
SCEVExpander &Expander, bool HoistRuntimeChecks = false);
616616

617+
/// Emit loop_dependence_{war,raw}_mask intrinsic calls before \p Loc for the
/// pointer pairs in \p Checks, producing lane masks with \p VF elements.
/// \p Expander is used to expand the start expressions of each pair.
/// \return an i1 value derived from the resulting lane mask; it is true when
/// at least one lane of that mask is set.
LLVM_ABI Value *addSafeEltsRuntimeChecks(Instruction *Loc,
618+
ArrayRef<PointerDiffInfo> Checks,
619+
SCEVExpander &Expander,
620+
ElementCount VF);
621+
617622
LLVM_ABI Value *addDiffRuntimeChecks(
618623
Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
619624
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC);

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,14 @@ bool RuntimePointerChecking::tryToCreateDiffCheck(
511511
}
512512
}
513513

514+
bool WriteAfterRead = !Src->IsWritePtr && Sink->IsWritePtr;
515+
514516
LLVM_DEBUG(dbgs() << "LAA: Creating diff runtime check for:\n"
515517
<< "SrcStart: " << *SrcStartInt << '\n'
516518
<< "SinkStartInt: " << *SinkStartInt << '\n');
517519
DiffChecks.emplace_back(SrcStartInt, SinkStartInt, AllocSize,
518-
Src->NeedsFreeze || Sink->NeedsFreeze);
520+
Src->NeedsFreeze || Sink->NeedsFreeze,
521+
WriteAfterRead);
519522
return true;
520523
}
521524

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,10 @@ InstructionCost TargetTransformInfo::getPartialReductionCost(
878878
BinOp, CostKind);
879879
}
880880

881+
bool TargetTransformInfo::useSafeEltsMask() const {
882+
return TTIImpl->useSafeEltsMask();
883+
}
884+
881885
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
882886
return TTIImpl->getMaxInterleaveFactor(VF);
883887
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,8 +1073,6 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10731073
return LegalCost;
10741074
}
10751075
break;
1076-
default:
1077-
break;
10781076
}
10791077
return BaseT::getIntrinsicInstrCost(ICA, CostKind);
10801078
}
@@ -5880,6 +5878,11 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
58805878
return Cost + 2;
58815879
}
58825880

5881+
bool AArch64TTIImpl::useSafeEltsMask() const {
5882+
// The whilewr/whilerw instructions require SVE2; SME is accepted as well.
5883+
return ST->hasSVE2() || ST->hasSME();
5884+
}
5885+
58835886
InstructionCost
58845887
AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
58855888
VectorType *SrcTy, ArrayRef<int> Mask,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
406406
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
407407
TTI::TargetCostKind CostKind) const override;
408408

409+
bool useSafeEltsMask() const override;
410+
409411
bool enableOrderedReductions() const override { return true; }
410412

411413
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,6 +2135,58 @@ Value *llvm::addRuntimeChecks(
21352135
return MemoryRuntimeCheck;
21362136
}
21372137

2138+
/// Emit a runtime check built from loop-dependence lane masks.
///
/// For each entry in \p Checks, the source and sink start expressions are
/// expanded before \p Loc, converted to pointers and passed to the
/// loop_dependence_war_mask or loop_dependence_raw_mask intrinsic (selected by
/// the entry's WriteAfterRead flag), producing a <VF x i1> mask. The masks of
/// all distinct (sink, source) pairs are ANDed together.
///
/// \param Loc      Instruction before which all new IR is inserted.
/// \param Checks   Pointer pairs to build masks for; at least one mask must
///                 result, otherwise the assertion below fires.
/// \param Expander Used to expand the SCEV start expressions.
/// \param VF       Element count of the generated mask vectors.
/// \return An i1 value that is true when at least one lane of the combined
///         mask is set.
Value *llvm::addSafeEltsRuntimeChecks(Instruction *Loc,
                                      ArrayRef<PointerDiffInfo> Checks,
                                      SCEVExpander &Expander, ElementCount VF) {
  IRBuilder ChkBuilder(Loc->getContext(),
                       InstSimplifyFolder(Loc->getDataLayout()));
  ChkBuilder.SetInsertPoint(Loc);

  // Map to keep track of created masks. The key is the pair of expanded
  // operands, to allow detecting and skipping redundant checks.
  DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
  Value *AliasLaneMask = nullptr;
  // NOTE(review): NeedsFreeze is currently unused in this function -- confirm
  // whether the expanded pointers need to be frozen before use.
  for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze,
                    WriteAfterRead] : Checks) {
    Type *Ty = SinkStart->getType();
    Value *Sink = Expander.expandCodeFor(SinkStart, Ty, Loc);
    Value *Src = Expander.expandCodeFor(SrcStart, Ty, Loc);
    if (SeenCompares.lookup({Sink, Src}))
      continue;

    unsigned IntOpc = WriteAfterRead ? Intrinsic::loop_dependence_war_mask
                                     : Intrinsic::loop_dependence_raw_mask;
    Value *SourceAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Src,
                                               ChkBuilder.getPtrTy());
    Value *SinkAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Sink,
                                             ChkBuilder.getPtrTy());
    Value *M = ChkBuilder.CreateIntrinsic(
        IntOpc, {VectorType::get(ChkBuilder.getInt1Ty(), VF)},
        {SourceAsPtr, SinkAsPtr, ChkBuilder.getInt64(AccessSize)}, nullptr,
        "alias.lane.mask");
    SeenCompares.insert({{Sink, Src}, M});

    // Accumulate into AliasLaneMask itself. Storing the AND in the loop-local
    // M would drop every mask after the first one.
    AliasLaneMask = AliasLaneMask ? ChkBuilder.CreateAnd(AliasLaneMask, M) : M;
  }
  assert(AliasLaneMask && "Expected an alias lane mask to have been created.");

  // "Any lane set" is an OR-reduction of the i1 mask. This avoids counting
  // lanes in an i8 accumulator, which wraps to zero when 256 lanes are set.
  return ChkBuilder.CreateUnaryIntrinsic(Intrinsic::vector_reduce_or,
                                         AliasLaneMask);
}
2189+
21382190
Value *llvm::addDiffRuntimeChecks(
21392191
Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
21402192
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {
@@ -2149,7 +2201,8 @@ Value *llvm::addDiffRuntimeChecks(
21492201
// Map to keep track of created compares, The key is the pair of operands for
21502202
// the compare, to allow detecting and re-using redundant compares.
21512203
DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
2152-
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
2204+
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze,
2205+
WriteAfterRead] : Checks) {
21532206
Type *Ty = SinkStart->getType();
21542207
// Compute VF * IC * AccessSize.
21552208
auto *VFTimesICTimesSize =
@@ -2158,8 +2211,8 @@ Value *llvm::addDiffRuntimeChecks(
21582211
Value *Diff =
21592212
Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
21602213

2161-
// Check if the same compare has already been created earlier. In that case,
2162-
// there is no need to check it again.
2214+
// Check if the same compare has already been created earlier. In that
2215+
// case, there is no need to check it again.
21632216
Value *IsConflict = SeenCompares.lookup({Diff, VFTimesICTimesSize});
21642217
if (IsConflict)
21652218
continue;

llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,11 @@ class LoopVectorizationPlanner {
506506
/// Build VPlans for the specified \p UserVF and \p UserIC if they are
507507
/// non-zero or all applicable candidate VFs otherwise. If vectorization and
508508
/// interleaving should be avoided up-front, no plans are generated.
509-
void plan(ElementCount UserVF, unsigned UserIC);
509+
/// DiffChecks is a list of pointer pairs that should be checked for aliasing,
510+
/// combining the resulting predicate with an active lane mask if one is in
511+
/// use.
512+
void plan(ElementCount UserVF, unsigned UserIC,
513+
std::optional<ArrayRef<PointerDiffInfo>> DiffChecks);
510514

511515
/// Use the VPlan-native path to plan how to best vectorize, return the best
512516
/// VF and its cost.

0 commit comments

Comments
 (0)