Skip to content

Commit 7e78c73

Browse files
committed
[LV] Mask off possibly aliasing vector lanes
When vectorising a loop that uses loads and stores, the pointers being accessed could overlap if their difference is less than the vectorisation factor. For example, if address 20 is being stored to and address 23 is being loaded from, they overlap when the vectorisation factor is 4 or higher. Currently LoopVectorize uses a runtime check to branch to a scalar loop in these cases. However, if we construct a mask that disables the overlapping (aliasing) lanes, then the vectorised loop can be safely entered, as long as the loads and stores are predicated on that mask.
1 parent 62e786a commit 7e78c73

20 files changed

+847
-24
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,12 @@ struct PointerDiffInfo {
491491
const SCEV *SinkStart;
492492
unsigned AccessSize;
493493
bool NeedsFreeze;
494+
bool WriteAfterRead;
494495

495496
PointerDiffInfo(const SCEV *SrcStart, const SCEV *SinkStart,
496-
unsigned AccessSize, bool NeedsFreeze)
497+
unsigned AccessSize, bool NeedsFreeze, bool WriteAfterRead)
497498
: SrcStart(SrcStart), SinkStart(SinkStart), AccessSize(AccessSize),
498-
NeedsFreeze(NeedsFreeze) {}
499+
NeedsFreeze(NeedsFreeze), WriteAfterRead(WriteAfterRead) {}
499500
};
500501

501502
/// Holds information about the memory runtime legality checks to verify

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,13 @@ enum class TailFoldingStyle {
205205
DataWithEVL,
206206
};
207207

208+
enum class RTCheckStyle {
209+
/// Create runtime checks based on the difference between two pointers
210+
ScalarDifference,
211+
/// Form a mask based on elements which won't be a WAR or RAW hazard.
212+
UseSafeEltsMask,
213+
};
214+
208215
struct TailFoldingInfo {
209216
TargetLibraryInfo *TLI;
210217
LoopVectorizationLegality *LVL;
@@ -1354,6 +1361,11 @@ class TargetTransformInfo {
13541361
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
13551362
TTI::TargetCostKind CostKind) const;
13561363

1364+
/// \return true if a mask should be formed that disables lanes that could
1365+
/// alias between two pointers. The mask is created by the
1366+
/// loop_dependence_{war,raw}_mask intrinsics.
1367+
LLVM_ABI bool useSafeEltsMask() const;
1368+
13571369
/// \return The maximum interleave factor that any transform should try to
13581370
/// perform for this target. This number depends on the level of parallelism
13591371
/// and the number of execution units in the CPU.

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,8 @@ class TargetTransformInfoImplBase {
659659
return InstructionCost::getInvalid();
660660
}
661661

662+
virtual bool useSafeEltsMask() const { return false; }
663+
662664
virtual unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }
663665

664666
virtual InstructionCost getArithmeticInstrCost(

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,53 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
21642164
// Otherwise, fallback to default scalarization cost.
21652165
break;
21662166
}
2167+
case Intrinsic::loop_dependence_raw_mask:
2168+
case Intrinsic::loop_dependence_war_mask: {
2169+
InstructionCost Cost = 0;
2170+
Type *PtrTy = ICA.getArgTypes()[0];
2171+
bool IsReadAfterWrite = IID == Intrinsic::loop_dependence_raw_mask;
2172+
2173+
Cost +=
2174+
thisT()->getArithmeticInstrCost(Instruction::Sub, PtrTy, CostKind);
2175+
if (IsReadAfterWrite) {
2176+
IntrinsicCostAttributes AbsAttrs(Intrinsic::abs, PtrTy, {PtrTy}, {});
2177+
Cost += thisT()->getIntrinsicInstrCost(AbsAttrs, CostKind);
2178+
}
2179+
2180+
Cost +=
2181+
thisT()->getArithmeticInstrCost(Instruction::SDiv, PtrTy, CostKind);
2182+
Type *CmpTy =
2183+
getTLI()
2184+
->getSetCCResultType(
2185+
thisT()->getDataLayout(), RetTy->getContext(),
2186+
getTLI()->getValueType(thisT()->getDataLayout(), PtrTy))
2187+
.getTypeForEVT(RetTy->getContext());
2188+
Cost += thisT()->getCmpSelInstrCost(
2189+
BinaryOperator::ICmp, CmpTy, PtrTy,
2190+
IsReadAfterWrite ? CmpInst::ICMP_EQ : CmpInst::ICMP_SLE, CostKind);
2191+
2192+
// The deconstructed active lane mask
2193+
VectorType *RetTyVec = cast<VectorType>(RetTy);
2194+
VectorType *SplatTy = cast<VectorType>(RetTyVec->getWithNewType(PtrTy));
2195+
Cost += thisT()->getShuffleCost(TTI::SK_Broadcast, SplatTy, SplatTy, {},
2196+
CostKind, 0, nullptr);
2197+
IntrinsicCostAttributes StepVecAttrs(Intrinsic::stepvector, SplatTy, {},
2198+
FMF);
2199+
Cost += thisT()->getIntrinsicInstrCost(StepVecAttrs, CostKind);
2200+
Cost += thisT()->getCmpSelInstrCost(BinaryOperator::ICmp, SplatTy,
2201+
SplatTy, CmpInst::ICMP_ULT, CostKind);
2202+
2203+
Cost +=
2204+
thisT()->getCastInstrCost(Instruction::CastOps::ZExt, RetTy, SplatTy,
2205+
TTI::CastContextHint::None, CostKind);
2206+
Cost += thisT()->getCastInstrCost(Instruction::CastOps::ZExt,
2207+
RetTyVec->getElementType(), CmpTy,
2208+
TTI::CastContextHint::None, CostKind);
2209+
Cost += thisT()->getShuffleCost(TTI::SK_Broadcast, RetTyVec, RetTyVec, {},
2210+
CostKind, 0, nullptr);
2211+
Cost += thisT()->getArithmeticInstrCost(Instruction::Or, RetTy, CostKind);
2212+
return Cost;
2213+
}
21672214
}
21682215

21692216
// Assume that we need to scalarize this intrinsic.)

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,11 @@ addRuntimeChecks(Instruction *Loc, Loop *TheLoop,
580580
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
581581
SCEVExpander &Expander, bool HoistRuntimeChecks = false);
582582

583+
LLVM_ABI Value *addSafeEltsRuntimeChecks(Instruction *Loc,
584+
ArrayRef<PointerDiffInfo> Checks,
585+
SCEVExpander &Expander,
586+
ElementCount VF);
587+
583588
LLVM_ABI Value *addDiffRuntimeChecks(
584589
Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
585590
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC);

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,11 +511,14 @@ bool RuntimePointerChecking::tryToCreateDiffCheck(
511511
}
512512
}
513513

514+
bool WriteAfterRead = !Src->IsWritePtr && Sink->IsWritePtr;
515+
514516
LLVM_DEBUG(dbgs() << "LAA: Creating diff runtime check for:\n"
515517
<< "SrcStart: " << *SrcStartInt << '\n'
516518
<< "SinkStartInt: " << *SinkStartInt << '\n');
517519
DiffChecks.emplace_back(SrcStartInt, SinkStartInt, AllocSize,
518-
Src->NeedsFreeze || Sink->NeedsFreeze);
520+
Src->NeedsFreeze || Sink->NeedsFreeze,
521+
WriteAfterRead);
519522
return true;
520523
}
521524

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,10 @@ InstructionCost TargetTransformInfo::getPartialReductionCost(
878878
BinOp, CostKind);
879879
}
880880

881+
bool TargetTransformInfo::useSafeEltsMask() const {
882+
return TTIImpl->useSafeEltsMask();
883+
}
884+
881885
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
882886
return TTIImpl->getMaxInterleaveFactor(VF);
883887
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,12 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10041004
}
10051005
break;
10061006
}
1007+
case Intrinsic::loop_dependence_raw_mask:
1008+
case Intrinsic::loop_dependence_war_mask:
1009+
// The whilewr/rw instructions require SVE2
1010+
if (ST->hasSVE2() || ST->hasSME())
1011+
return 1;
1012+
break;
10071013
default:
10081014
break;
10091015
}
@@ -5725,6 +5731,11 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
57255731
return Cost + 4;
57265732
}
57275733

5734+
/// \return true when the subtarget reports SVE2 or SME.
bool AArch64TTIImpl::useSafeEltsMask() const {
5735+
// The whilewr/rw instructions require SVE2; SME is accepted here as well.
5736+
return ST->hasSVE2() || ST->hasSME();
5737+
}
5738+
57285739
InstructionCost
57295740
AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
57305741
VectorType *SrcTy, ArrayRef<int> Mask,

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
398398
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
399399
TTI::TargetCostKind CostKind) const override;
400400

401+
bool useSafeEltsMask() const override;
402+
401403
bool enableOrderedReductions() const override { return true; }
402404

403405
InstructionCost getInterleavedMemoryOpCost(

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,6 +2094,58 @@ Value *llvm::addRuntimeChecks(
20942094
return MemoryRuntimeCheck;
20952095
}
20962096

2097+
/// Emit, before \p Loc, one loop-dependence mask per pointer pair in
/// \p Checks and combine them into a single alias lane mask.
///
/// For every check a loop_dependence_war_mask (when the check is flagged
/// WriteAfterRead) or a loop_dependence_raw_mask intrinsic is created with a
/// <VF x i1> result. The masks of all checks are AND-ed together, so a lane is
/// only set in the combined mask if it is set for every pointer pair.
///
/// \param Loc      Instruction before which all new IR is inserted.
/// \param Checks   Pointer pairs to test; at least one mask must be produced.
/// \param Expander Used to materialise the SCEV start addresses.
/// \param VF       Element count of the generated masks.
/// \returns an i1 value that is true when at least one lane of the combined
/// mask is set.
Value *llvm::addSafeEltsRuntimeChecks(Instruction *Loc,
                                      ArrayRef<PointerDiffInfo> Checks,
                                      SCEVExpander &Expander, ElementCount VF) {
  IRBuilder ChkBuilder(Loc->getContext(),
                       InstSimplifyFolder(Loc->getDataLayout()));
  ChkBuilder.SetInsertPoint(Loc);

  // Both types are loop-invariant, so build them once.
  auto *MaskTy = VectorType::get(ChkBuilder.getInt1Ty(), VF);
  Type *PtrTy = ChkBuilder.getPtrTy();

  // Map to keep track of created masks. The key is the pair of expanded
  // operands, which allows detecting and skipping redundant checks.
  DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
  Value *AliasLaneMask = nullptr;
  // NOTE(review): NeedsFreeze is not consulted in this function — TODO confirm
  // that is intended.
  for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze,
                    WriteAfterRead] : Checks) {
    Type *Ty = SinkStart->getType();
    Value *Sink = Expander.expandCodeFor(SinkStart, Ty, Loc);
    Value *Src = Expander.expandCodeFor(SrcStart, Ty, Loc);
    if (SeenCompares.lookup({Sink, Src}))
      continue;

    Intrinsic::ID IntOpc = WriteAfterRead
                               ? Intrinsic::loop_dependence_war_mask
                               : Intrinsic::loop_dependence_raw_mask;
    // The start addresses are expanded as integers; the intrinsics are called
    // with pointer operands.
    Value *SourceAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Src, PtrTy);
    Value *SinkAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Sink, PtrTy);
    Value *M = ChkBuilder.CreateIntrinsic(
        IntOpc, {MaskTy},
        {SourceAsPtr, SinkAsPtr, ChkBuilder.getInt64(AccessSize)}, nullptr,
        "alias.lane.mask");
    SeenCompares.insert({{Sink, Src}, M});

    // Accumulate into AliasLaneMask. Storing the AND into the loop-local mask
    // instead would silently drop every check after the first one.
    AliasLaneMask = AliasLaneMask ? ChkBuilder.CreateAnd(AliasLaneMask, M) : M;
  }
  assert(AliasLaneMask && "Expected an alias lane mask to have been created.");

  // "Any lane set" is an OR-reduction of the i1 mask. Counting lanes through
  // an i8 add-reduction would wrap to zero once 256 lanes are set, which a
  // scalable VF can reach.
  return ChkBuilder.CreateOrReduce(AliasLaneMask);
}
2148+
20972149
Value *llvm::addDiffRuntimeChecks(
20982150
Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
20992151
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {
@@ -2108,7 +2160,8 @@ Value *llvm::addDiffRuntimeChecks(
21082160
// Map to keep track of created compares, The key is the pair of operands for
21092161
// the compare, to allow detecting and re-using redundant compares.
21102162
DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
2111-
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
2163+
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze,
2164+
WriteAfterRead] : Checks) {
21122165
Type *Ty = SinkStart->getType();
21132166
// Compute VF * IC * AccessSize.
21142167
auto *VFTimesICTimesSize =
@@ -2117,8 +2170,8 @@ Value *llvm::addDiffRuntimeChecks(
21172170
Value *Diff =
21182171
Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
21192172

2120-
// Check if the same compare has already been created earlier. In that case,
2121-
// there is no need to check it again.
2173+
// Check if the same compare has already been created earlier. In that
2174+
// case, there is no need to check it again.
21222175
Value *IsConflict = SeenCompares.lookup({Diff, VFTimesICTimesSize});
21232176
if (IsConflict)
21242177
continue;

0 commit comments

Comments
 (0)