
Commit 2a46ad6 (1 parent: 4a78f18)

Transform the gather to stride load
File tree: 10 files changed (+568, -217 lines)
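The commit teaches the loop vectorizer to replace widened gathers whose address is a loop-invariant base plus an index advancing by a constant stride with strided loads (the RISC-V test below ends up with llvm.experimental.vp.strided.load). As a minimal sketch of the kind of source loop this targets, with an illustrative function that is not taken from the patch:

    #include <cstddef>

    // Each iteration reads src[3 * i], so vector lanes of the load sit a fixed
    // 3 * sizeof(short) = 6 bytes apart. Such an access used to be widened as a
    // gather; with this transform it can become a single strided load when the
    // target reports strided memory operations as legal and cheaper.
    void copy_every_third(short *dst, const short *src, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i)
        dst[i] = src[3 * i];
    }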

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 8 deletions

@@ -8424,20 +8424,15 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
                                  *Plan))
     return nullptr;
 
+  VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
+                        *CM.PSE.getSE());
   // Transform recipes to abstract recipes if it is legal and beneficial and
   // clamp the range for better cost estimation.
   // TODO: Enable following transform when the EVL-version of extended-reduction
   // and mulacc-reduction are implemented.
-  if (!CM.foldTailWithEVL()) {
-    VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind,
-                          *CM.PSE.getSE());
+  if (!CM.foldTailWithEVL())
     VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
                              CostCtx, Range);
-  }
-
-  for (ElementCount VF : Range)
-    Plan->addVF(VF);
-  Plan->setName("Initial VPlan");
 
   // Interleave memory: for each Interleave Group we marked earlier as relevant
   // for this VPlan, replace the Recipes widening its memory instructions with a
@@ -8450,6 +8445,15 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
   VPlanTransforms::runPass(VPlanTransforms::replaceSymbolicStrides, *Plan, PSE,
                            Legal->getLAI()->getSymbolicStrides());
 
+  // Convert memory recipes to strided access recipes if the strided access is
+  // legal and profitable.
+  VPlanTransforms::runPass(VPlanTransforms::convertToStridedAccesses, *Plan,
+                           CostCtx, Range);
+
+  for (ElementCount VF : Range)
+    Plan->addVF(VF);
+  Plan->setName("Initial VPlan");
+
   auto BlockNeedsPredication = [this](BasicBlock *BB) {
     return Legal->blockNeedsPredication(BB);
   };

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 23 additions & 4 deletions

@@ -1774,10 +1774,6 @@ struct LLVM_ABI_FOR_TEST VPWidenSelectRecipe : public VPRecipeWithIRFlags,
 class LLVM_ABI_FOR_TEST VPWidenGEPRecipe : public VPRecipeWithIRFlags {
   Type *SourceElementTy;
 
-  bool isPointerLoopInvariant() const {
-    return getOperand(0)->isDefinedOutsideLoopRegions();
-  }
-
   bool isIndexLoopInvariant(unsigned I) const {
     return getOperand(I + 1)->isDefinedOutsideLoopRegions();
   }
@@ -1810,6 +1806,29 @@ class LLVM_ABI_FOR_TEST VPWidenGEPRecipe : public VPRecipeWithIRFlags {
   /// This recipe generates a GEP instruction.
   unsigned getOpcode() const { return Instruction::GetElementPtr; }
 
+  bool isPointerLoopInvariant() const {
+    return getOperand(0)->isDefinedOutsideLoopRegions();
+  }
+
+  std::optional<unsigned> getUniqueVariantIndex() const {
+    std::optional<unsigned> VarIdx;
+    for (unsigned I = 0, E = getNumOperands() - 1; I < E; ++I) {
+      if (isIndexLoopInvariant(I))
+        continue;
+
+      if (VarIdx)
+        return std::nullopt;
+      VarIdx = I;
+    }
+    return VarIdx;
+  }
+
+  Type *getIndexedType(unsigned I) const {
+    auto *GEP = cast<GetElementPtrInst>(getUnderlyingInstr());
+    SmallVector<Value *, 4> Ops(GEP->idx_begin(), GEP->idx_begin() + I);
+    return GetElementPtrInst::getIndexedType(SourceElementTy, Ops);
+  }
+
   /// Generate the gep nodes.
   void execute(VPTransformState &State) override;
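The helpers promoted or added here feed the new transform: getUniqueVariantIndex() reports the position of the single GEP index that varies inside the loop (or nothing if zero or several vary), and getIndexedType(I) hands the indices before that position to GetElementPtrInst::getIndexedType so the transform can later scale the stride into bytes. A small conceptual mirror of the unique-index logic, using made-up inputs rather than VPlan recipes:

    #include <optional>
    #include <vector>

    // One flag per GEP index: true if that index is loop-invariant. Returns the
    // position of the single varying index, or nullopt if none or several vary,
    // mirroring the contract of VPWidenGEPRecipe::getUniqueVariantIndex().
    std::optional<unsigned> uniqueVariantIndex(const std::vector<bool> &IsInvariant) {
      std::optional<unsigned> VarIdx;
      for (unsigned I = 0; I < IsInvariant.size(); ++I) {
        if (IsInvariant[I])
          continue;
        if (VarIdx)
          return std::nullopt; // more than one varying index: no simple stride
        VarIdx = I;
      }
      return VarIdx;
    }

    int main() {
      // e.g. getelementptr [64 x i16], ptr %base, i64 0, i64 %i: only index 1 varies.
      return uniqueVariantIndex({true, false}) == std::optional<unsigned>(1) ? 0 : 1;
    }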

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 181 additions & 0 deletions

@@ -4492,3 +4492,184 @@ void VPlanTransforms::addExitUsersForFirstOrderRecurrences(VPlan &Plan,
     }
   }
 }
+
+static std::pair<VPValue *, VPValue *> matchStridedStart(VPValue *CurIndex) {
+  // TODO: Support VPWidenPointerInductionRecipe.
+  if (auto *WidenIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(CurIndex))
+    return {WidenIV, WidenIV->getStepValue()};
+
+  auto *WidenR = dyn_cast<VPWidenRecipe>(CurIndex);
+  if (!WidenR || !CurIndex->getUnderlyingValue())
+    return {nullptr, nullptr};
+
+  unsigned Opcode = WidenR->getOpcode();
+  // TODO: Support Instruction::Add and Instruction::Or.
+  if (Opcode != Instruction::Shl && Opcode != Instruction::Mul)
+    return {nullptr, nullptr};
+
+  // Match the pattern binop(variant, invariant), or binop(invariant, variant)
+  // if the binary operator is commutative.
+  bool IsLHSUniform = vputils::isSingleScalar(WidenR->getOperand(0));
+  if (IsLHSUniform == vputils::isSingleScalar(WidenR->getOperand(1)) ||
+      (IsLHSUniform && !Instruction::isCommutative(Opcode)))
+    return {nullptr, nullptr};
+  unsigned VarIdx = IsLHSUniform ? 1 : 0;
+
+  auto [Start, Stride] = matchStridedStart(WidenR->getOperand(VarIdx));
+  if (!Start)
+    return {nullptr, nullptr};
+
+  SmallVector<VPValue *> StartOps(WidenR->operands());
+  StartOps[VarIdx] = Start;
+  auto *StartR = new VPReplicateRecipe(WidenR->getUnderlyingInstr(), StartOps,
+                                       /*IsUniform*/ true);
+  StartR->insertBefore(WidenR);
+
+  unsigned InvIdx = VarIdx == 0 ? 1 : 0;
+  auto *StrideR =
+      new VPInstruction(Opcode, {Stride, WidenR->getOperand(InvIdx)});
+  StrideR->insertBefore(WidenR);
+  return {StartR, StrideR};
+}
+
+static std::tuple<VPValue *, VPValue *, Type *>
+determineBaseAndStride(VPWidenGEPRecipe *WidenGEP) {
+  // TODO: Check if the base pointer is strided.
+  if (!WidenGEP->isPointerLoopInvariant())
+    return {nullptr, nullptr, nullptr};
+
+  // Find the only one variant index.
+  std::optional<unsigned> VarIndex = WidenGEP->getUniqueVariantIndex();
+  if (!VarIndex)
+    return {nullptr, nullptr, nullptr};
+
+  Type *ElementTy = WidenGEP->getIndexedType(*VarIndex);
+  if (ElementTy->isScalableTy() || ElementTy->isStructTy() ||
+      ElementTy->isVectorTy())
+    return {nullptr, nullptr, nullptr};
+
+  unsigned VarOp = *VarIndex + 1;
+  VPValue *IndexVPV = WidenGEP->getOperand(VarOp);
+  auto [Start, Stride] = matchStridedStart(IndexVPV);
+  if (!Start)
+    return {nullptr, nullptr, nullptr};
+
+  SmallVector<VPValue *> Ops(WidenGEP->operands());
+  Ops[VarOp] = Start;
+  auto *BasePtr = new VPReplicateRecipe(WidenGEP->getUnderlyingInstr(), Ops,
+                                        /*IsUniform*/ true);
+  BasePtr->insertBefore(WidenGEP);
+
+  return {BasePtr, Stride, ElementTy};
+}
+
+void VPlanTransforms::convertToStridedAccesses(VPlan &Plan, VPCostContext &Ctx,
+                                               VFRange &Range) {
+  if (Plan.hasScalarVFOnly())
+    return;
+
+  VPTypeAnalysis TypeInfo(Plan);
+  DenseMap<VPWidenGEPRecipe *, std::tuple<VPValue *, VPValue *, Type *>>
+      StrideCache;
+  SmallVector<VPRecipeBase *> ToErase;
+  SmallPtrSet<VPValue *, 4> PossiblyDead;
+  for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
+           vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) {
+    for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
+      auto *MemR = dyn_cast<VPWidenMemoryRecipe>(&R);
+      // TODO: Support strided store.
+      // TODO: Transform reverse access into strided access with -1 stride.
+      // TODO: Transform gather/scatter with uniform address into strided access
+      // with 0 stride.
+      // TODO: Transform interleave access into multiple strided accesses.
+      if (!MemR || !isa<VPWidenLoadRecipe>(MemR) || MemR->isConsecutive())
+        continue;
+
+      auto *Ptr = dyn_cast<VPWidenGEPRecipe>(MemR->getAddr());
+      if (!Ptr)
+        continue;
+
+      // Memory cost model requires the pointer operand of memory access
+      // instruction.
+      Value *PtrUV = Ptr->getUnderlyingValue();
+      if (!PtrUV)
+        continue;
+
+      // Try to get base and stride here.
+      VPValue *BasePtr, *StrideInElement;
+      Type *ElementTy;
+      auto It = StrideCache.find(Ptr);
+      if (It != StrideCache.end())
+        std::tie(BasePtr, StrideInElement, ElementTy) = It->second;
+      else
+        std::tie(BasePtr, StrideInElement, ElementTy) = StrideCache[Ptr] =
+            determineBaseAndStride(Ptr);
+
+      // Skip if the memory access is not a strided access.
+      if (!BasePtr) {
+        assert(!StrideInElement && !ElementTy);
+        continue;
+      }
+      assert(StrideInElement && ElementTy);
+
+      Instruction &Ingredient = MemR->getIngredient();
+      auto IsProfitable = [&](ElementCount VF) -> bool {
+        Type *DataTy = toVectorTy(getLoadStoreType(&Ingredient), VF);
+        const Align Alignment = getLoadStoreAlignment(&Ingredient);
+        if (!Ctx.TTI.isLegalStridedLoadStore(DataTy, Alignment))
+          return false;
+        const InstructionCost CurrentCost = MemR->computeCost(VF, Ctx);
+        const InstructionCost StridedLoadStoreCost =
+            Ctx.TTI.getStridedMemoryOpCost(Instruction::Load, DataTy, PtrUV,
+                                           MemR->isMasked(), Alignment,
+                                           Ctx.CostKind, &Ingredient);
+        return StridedLoadStoreCost < CurrentCost;
+      };
+
+      if (!LoopVectorizationPlanner::getDecisionAndClampRange(IsProfitable,
+                                                              Range)) {
+        PossiblyDead.insert(BasePtr);
+        PossiblyDead.insert(StrideInElement);
+        continue;
+      }
+      PossiblyDead.insert(Ptr);
+
+      // Create a new vector pointer for strided access.
+      auto *GEP = dyn_cast<GetElementPtrInst>(PtrUV->stripPointerCasts());
+      auto *NewPtr = new VPVectorPointerRecipe(
+          BasePtr, ElementTy, StrideInElement,
+          GEP ? GEP->getNoWrapFlags() : GEPNoWrapFlags::none(),
+          Ptr->getDebugLoc());
+      NewPtr->insertBefore(MemR);
+
+      const DataLayout &DL = Ingredient.getDataLayout();
+      TypeSize TS = DL.getTypeAllocSize(ElementTy);
+      unsigned TypeScale = TS.getFixedValue();
+      VPValue *StrideInBytes = StrideInElement;
+      // Scale the stride by the size of the indexed type.
+      if (TypeScale != 1) {
+        VPValue *ScaleVPV = Plan.getOrAddLiveIn(ConstantInt::get(
+            TypeInfo.inferScalarType(StrideInElement), TypeScale));
+        auto *ScaledStride =
+            new VPInstruction(Instruction::Mul, {StrideInElement, ScaleVPV});
+        ScaledStride->insertBefore(MemR);
+        StrideInBytes = ScaledStride;
+      }
+
+      auto *LoadR = cast<VPWidenLoadRecipe>(MemR);
+      auto *StridedLoad = new VPWidenStridedLoadRecipe(
+          *cast<LoadInst>(&Ingredient), NewPtr, StrideInBytes, &Plan.getVF(),
+          LoadR->getMask(), *LoadR, LoadR->getDebugLoc());
+      StridedLoad->insertBefore(LoadR);
+      LoadR->replaceAllUsesWith(StridedLoad);
+
+      ToErase.push_back(LoadR);
+    }
+  }
+
+  // Clean up dead memory access recipes, and unused base address and stride.
+  for (auto *R : ToErase)
+    R->eraseFromParent();
+  for (auto *V : PossiblyDead)
+    recursivelyDeleteDeadRecipes(V);
+}
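matchStridedStart above works by recursion on the index expression: a widened integer induction contributes its (start, step) pair directly, and a Shl or Mul with exactly one loop-invariant operand forwards the pair of its varying operand with both components rescaled by the invariant operand (the real code materializes these as a uniform VPReplicateRecipe and a VPInstruction). A self-contained sketch of that recursion on a toy expression type follows; Expr and the helper are invented for illustration and are not VPlan code:

    #include <cstdint>
    #include <optional>

    // Toy expression node: either an induction (start, step) or a Shl/Mul of a
    // varying subexpression by a loop-invariant constant.
    struct Expr {
      enum Kind { Induction, Mul, Shl } K;
      const Expr *Variant = nullptr;    // varying operand (Mul/Shl only)
      std::int64_t Invariant = 0;       // invariant operand (Mul/Shl only)
      std::int64_t Start = 0, Step = 0; // Induction only
    };

    struct StartStride { std::int64_t Start, Stride; };

    // Value of the expression in the first iteration and its per-iteration
    // advance, or nullopt if the expression is not a simple strided index.
    std::optional<StartStride> matchStridedStart(const Expr &E) {
      if (E.K == Expr::Induction)
        return StartStride{E.Start, E.Step};
      auto Inner = matchStridedStart(*E.Variant);
      if (!Inner)
        return std::nullopt;
      if (E.K == Expr::Mul)
        return StartStride{Inner->Start * E.Invariant, Inner->Stride * E.Invariant};
      return StartStride{Inner->Start << E.Invariant, Inner->Stride << E.Invariant};
    }

    int main() {
      // Index i * 4 written as (i << 2) with i = 0, 1, 2, ...: start 0, stride 4.
      Expr IV{Expr::Induction, nullptr, 0, 0, 1};
      Expr Shifted{Expr::Shl, &IV, 2, 0, 0};
      auto R = matchStridedStart(Shifted);
      return (R && R->Start == 0 && R->Stride == 4) ? 0 : 1;
    }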

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 6 additions & 0 deletions

@@ -246,6 +246,12 @@ struct VPlanTransforms {
           &InterleaveGroups,
       VPRecipeBuilder &RecipeBuilder, const bool &ScalarEpilogueAllowed);
 
+  /// Transform widen memory recipes into strided access recipes when legal
+  /// and profitable. Clamps \p Range to maintain consistency with widen
+  /// decisions of \p Plan, and uses \p Ctx to evaluate the cost.
+  static void convertToStridedAccesses(VPlan &Plan, VPCostContext &Ctx,
+                                       VFRange &Range);
+
   /// Remove dead recipes from \p Plan.
   static void removeDeadRecipes(VPlan &Plan);
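The "clamps \p Range" part of the new doc comment follows the planner's usual per-VF decision scheme: the profitability predicate is evaluated per VF and the range is cut at the first VF that disagrees with the decision for the smallest one, so every plan built for the range reflects a single decision. A rough, self-contained illustration of that clamping idea, where plain integers stand in for ElementCount and this is not the LLVM implementation:

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Evaluate Predicate on the smallest VF, then shrink Range so all remaining
    // VFs share that decision, roughly what
    // LoopVectorizationPlanner::getDecisionAndClampRange does for a VFRange.
    bool decideAndClamp(const std::function<bool(int)> &Predicate,
                        std::vector<int> &Range) {
      bool Decision = Predicate(Range.front());
      for (std::size_t I = 1; I < Range.size(); ++I) {
        if (Predicate(Range[I]) != Decision) {
          Range.resize(I); // keep only the prefix with a uniform decision
          break;
        }
      }
      return Decision;
    }

    int main() {
      std::vector<int> VFs = {2, 4, 8, 16};
      // Pretend the strided load is only profitable up to VF = 8.
      bool Profitable = decideAndClamp([](int VF) { return VF <= 8; }, VFs);
      std::printf("profitable=%d, VFs kept=%zu\n", Profitable, VFs.size()); // 1, 3
    }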

llvm/test/Transforms/LoopVectorize/RISCV/blocks-with-dead-instructions.ll

Lines changed: 12 additions & 1 deletion

@@ -315,6 +315,8 @@ define void @multiple_blocks_with_dead_inst_multiple_successors_6(ptr %src, i1 %
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 1
 ; CHECK-NEXT:    br label %[[VECTOR_PH:.*]]
 ; CHECK:       [[VECTOR_PH]]:
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 8 x i1> poison, i1 [[IC]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 8 x i1> [[BROADCAST_SPLATINSERT]], <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP8:%.*]] = xor <vscale x 8 x i1> [[BROADCAST_SPLAT]], splat (i1 true)
@@ -323,22 +325,31 @@ define void @multiple_blocks_with_dead_inst_multiple_successors_6(ptr %src, i1 %
 ; CHECK-NEXT:    [[INDUCTION:%.*]] = add <vscale x 8 x i64> zeroinitializer, [[TMP13]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
+; CHECK-NEXT:    [[EVL_BASED_IV:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <vscale x 8 x i64> [ [[INDUCTION]], %[[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[AVL:%.*]] = phi i64 [ [[TMP2]], %[[VECTOR_PH]] ], [ [[AVL_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP27:%.*]] = call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 8, i1 true)
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT3:%.*]] = insertelement <vscale x 8 x i32> poison, i32 [[TMP27]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT4:%.*]] = shufflevector <vscale x 8 x i32> [[BROADCAST_SPLATINSERT3]], <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP12:%.*]] = zext i32 [[TMP27]] to i64
 ; CHECK-NEXT:    [[TMP16:%.*]] = mul i64 3, [[TMP12]]
 ; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 8 x i64> poison, i64 [[TMP16]], i64 0
 ; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 8 x i64> [[DOTSPLATINSERT]], <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP18:%.*]] = call <vscale x 8 x i32> @llvm.stepvector.nxv8i32()
+; CHECK-NEXT:    [[TMP19:%.*]] = icmp ult <vscale x 8 x i32> [[TMP18]], [[BROADCAST_SPLAT4]]
+; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = mul i64 [[EVL_BASED_IV]], 3
+; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i16, ptr [[SRC]], i64 [[OFFSET_IDX]]
 ; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr i16, ptr [[SRC]], <vscale x 8 x i64> [[VEC_IND]]
-; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 8 x i16> @llvm.vp.gather.nxv8i16.nxv8p0(<vscale x 8 x ptr> align 2 [[TMP20]], <vscale x 8 x i1> splat (i1 true), i32 [[TMP27]])
+; CHECK-NEXT:    [[TMP15:%.*]] = trunc i64 [[TMP4]] to i32
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 8 x i16> @llvm.experimental.vp.strided.load.nxv8i16.p0.i64(ptr align 2 [[TMP21]], i64 6, <vscale x 8 x i1> [[TMP19]], i32 [[TMP15]])
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq <vscale x 8 x i16> [[WIDE_MASKED_GATHER]], zeroinitializer
 ; CHECK-NEXT:    [[TMP14:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i1> [[TMP8]], <vscale x 8 x i1> zeroinitializer
 ; CHECK-NEXT:    [[TMP28:%.*]] = xor <vscale x 8 x i1> [[TMP17]], splat (i1 true)
 ; CHECK-NEXT:    [[TMP22:%.*]] = or <vscale x 8 x i1> [[TMP14]], [[TMP28]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = select <vscale x 8 x i1> [[TMP17]], <vscale x 8 x i1> [[BROADCAST_SPLAT]], <vscale x 8 x i1> zeroinitializer
 ; CHECK-NEXT:    [[TMP24:%.*]] = or <vscale x 8 x i1> [[TMP22]], [[TMP23]]
 ; CHECK-NEXT:    call void @llvm.vp.scatter.nxv8i16.nxv8p0(<vscale x 8 x i16> zeroinitializer, <vscale x 8 x ptr> align 2 [[TMP20]], <vscale x 8 x i1> [[TMP24]], i32 [[TMP27]])
+; CHECK-NEXT:    [[INDEX_EVL_NEXT]] = add i64 [[TMP12]], [[EVL_BASED_IV]]
 ; CHECK-NEXT:    [[AVL_NEXT]] = sub nuw i64 [[AVL]], [[TMP12]]
 ; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 8 x i64> [[VEC_IND]], [[DOTSPLAT]]
 ; CHECK-NEXT:    [[TMP26:%.*]] = icmp eq i64 [[AVL_NEXT]], 0