Skip to content

Commit 4cf7702

Browse files
authored
[VPlan] Introduce replaceSymbolicStrides (NFC) (#155842)
Introduce VPlanTransforms::replaceSymbolicStrides factoring some code from LoopVectorize.
1 parent d7d8703 commit 4cf7702

File tree

3 files changed

+50
-35
lines changed

3 files changed

+50
-35
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8778,41 +8778,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
87788778
InterleaveGroups, RecipeBuilder,
87798779
CM.isScalarEpilogueAllowed());
87808780

8781-
// Replace VPValues for known constant strides guaranteed by predicate scalar
8782-
// evolution.
8783-
auto CanUseVersionedStride = [&Plan](VPUser &U, unsigned) {
8784-
auto *R = cast<VPRecipeBase>(&U);
8785-
return R->getParent()->getParent() ||
8786-
R->getParent() ==
8787-
Plan->getVectorLoopRegion()->getSinglePredecessor();
8788-
};
8789-
for (auto [_, Stride] : Legal->getLAI()->getSymbolicStrides()) {
8790-
auto *StrideV = cast<SCEVUnknown>(Stride)->getValue();
8791-
auto *ScevStride = dyn_cast<SCEVConstant>(PSE.getSCEV(StrideV));
8792-
// Only handle constant strides for now.
8793-
if (!ScevStride)
8794-
continue;
8795-
8796-
auto *CI = Plan->getOrAddLiveIn(
8797-
ConstantInt::get(Stride->getType(), ScevStride->getAPInt()));
8798-
if (VPValue *StrideVPV = Plan->getLiveIn(StrideV))
8799-
StrideVPV->replaceUsesWithIf(CI, CanUseVersionedStride);
8800-
8801-
// The versioned value may not be used in the loop directly but through a
8802-
// sext/zext. Add new live-ins in those cases.
8803-
for (Value *U : StrideV->users()) {
8804-
if (!isa<SExtInst, ZExtInst>(U))
8805-
continue;
8806-
VPValue *StrideVPV = Plan->getLiveIn(U);
8807-
if (!StrideVPV)
8808-
continue;
8809-
unsigned BW = U->getType()->getScalarSizeInBits();
8810-
APInt C = isa<SExtInst>(U) ? ScevStride->getAPInt().sext(BW)
8811-
: ScevStride->getAPInt().zext(BW);
8812-
VPValue *CI = Plan->getOrAddLiveIn(ConstantInt::get(U->getType(), C));
8813-
StrideVPV->replaceUsesWithIf(CI, CanUseVersionedStride);
8814-
}
8815-
}
8781+
// Replace VPValues for known constant strides.
8782+
VPlanTransforms::runPass(VPlanTransforms::replaceSymbolicStrides, *Plan, PSE,
8783+
Legal->getLAI()->getSymbolicStrides());
88168784

88178785
auto BlockNeedsPredication = [this](BasicBlock *BB) {
88188786
return Legal->blockNeedsPredication(BB);

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/Analysis/IVDescriptors.h"
3030
#include "llvm/Analysis/InstSimplifyFolder.h"
3131
#include "llvm/Analysis/LoopInfo.h"
32+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
3233
#include "llvm/Analysis/VectorUtils.h"
3334
#include "llvm/IR/Intrinsics.h"
3435
#include "llvm/IR/MDBuilder.h"
@@ -2539,6 +2540,46 @@ void VPlanTransforms::canonicalizeEVLLoops(VPlan &Plan) {
25392540
LatchExitingBr->eraseFromParent();
25402541
}
25412542

2543+
void VPlanTransforms::replaceSymbolicStrides(
2544+
VPlan &Plan, PredicatedScalarEvolution &PSE,
2545+
const DenseMap<Value *, const SCEV *> &StridesMap) {
2546+
// Replace VPValues for known constant strides guaranteed by predicate scalar
2547+
// evolution.
2548+
auto CanUseVersionedStride = [&Plan](VPUser &U, unsigned) {
2549+
auto *R = cast<VPRecipeBase>(&U);
2550+
return R->getParent()->getParent() ||
2551+
R->getParent() == Plan.getVectorLoopRegion()->getSinglePredecessor();
2552+
};
2553+
for (const SCEV *Stride : StridesMap.values()) {
2554+
using namespace SCEVPatternMatch;
2555+
auto *StrideV = cast<SCEVUnknown>(Stride)->getValue();
2556+
const APInt *StrideConst;
2557+
if (!match(PSE.getSCEV(StrideV), m_scev_APInt(StrideConst)))
2558+
// Only handle constant strides for now.
2559+
continue;
2560+
2561+
auto *CI =
2562+
Plan.getOrAddLiveIn(ConstantInt::get(Stride->getType(), *StrideConst));
2563+
if (VPValue *StrideVPV = Plan.getLiveIn(StrideV))
2564+
StrideVPV->replaceUsesWithIf(CI, CanUseVersionedStride);
2565+
2566+
// The versioned value may not be used in the loop directly but through a
2567+
// sext/zext. Add new live-ins in those cases.
2568+
for (Value *U : StrideV->users()) {
2569+
if (!isa<SExtInst, ZExtInst>(U))
2570+
continue;
2571+
VPValue *StrideVPV = Plan.getLiveIn(U);
2572+
if (!StrideVPV)
2573+
continue;
2574+
unsigned BW = U->getType()->getScalarSizeInBits();
2575+
APInt C =
2576+
isa<SExtInst>(U) ? StrideConst->sext(BW) : StrideConst->zext(BW);
2577+
VPValue *CI = Plan.getOrAddLiveIn(ConstantInt::get(U->getType(), C));
2578+
StrideVPV->replaceUsesWithIf(CI, CanUseVersionedStride);
2579+
}
2580+
}
2581+
}
2582+
25422583
void VPlanTransforms::dropPoisonGeneratingRecipes(
25432584
VPlan &Plan,
25442585
const std::function<bool(BasicBlock *)> &BlockNeedsPredication) {

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ struct VPlanTransforms {
199199
truncateToMinimalBitwidths(VPlan &Plan,
200200
const MapVector<Instruction *, uint64_t> &MinBWs);
201201

202+
/// Replace symbolic strides from \p StridesMap in \p Plan with constants when
203+
/// possible.
204+
static void
205+
replaceSymbolicStrides(VPlan &Plan, PredicatedScalarEvolution &PSE,
206+
const DenseMap<Value *, const SCEV *> &StridesMap);
207+
202208
/// Drop poison flags from recipes that may generate a poison value that is
203209
/// used after vectorization, even when their operands are not poison. Those
204210
/// recipes meet the following conditions:

0 commit comments

Comments
 (0)