Skip to content

Commit 3769e1f

Browse files
committed
[VPlan] Implement transformation for widen-cast/widen-mul + reduction to abstract recipe.
This patch introduces two new recipes. * VPExtendedReductionRecipe - cast + reduction. * VPMulAccumulateReductionRecipe - (cast) + mul + reduction. This patch also implements the transformation that matches the following patterns via VPlan and converts them to abstract recipes for better cost estimation. * VPExtendedReductionRecipe - reduce(cast(...)) * VPMulAccumulateReductionRecipe - reduce.add(mul(...)) - reduce.add(mul(ext(...), ext(...))) - reduce.add(ext(mul(ext(...), ext(...)))) The converted abstract recipes will be lowered to the concrete recipes (widen-cast + widen-mul + reduction) just before recipe execution. Split from llvm#113903.
1 parent 690a30f commit 3769e1f

File tree

12 files changed

+838
-78
lines changed

12 files changed

+838
-78
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9568,10 +9568,6 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
95689568
"entry block must be set to a VPRegionBlock having a non-empty entry "
95699569
"VPBasicBlock");
95709570

9571-
for (ElementCount VF : Range)
9572-
Plan->addVF(VF);
9573-
Plan->setName("Initial VPlan");
9574-
95759571
// Update wide induction increments to use the same step as the corresponding
95769572
// wide induction. This enables detecting induction increments directly in
95779573
// VPlan and removes redundant splats.
@@ -9601,6 +9597,21 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range,
96019597
// Adjust the recipes for any inloop reductions.
96029598
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
96039599

9600+
// Transform recipes to abstract recipes if it is legal and beneficial and
9601+
// clamp the range for better cost estimation.
9602+
// TODO: Enable following transform when the EVL-version of extended-reduction
9603+
// and mulacc-reduction are implemented.
9604+
if (!CM.foldTailWithEVL()) {
9605+
VPCostContext CostCtx(CM.TTI, *CM.TLI, Legal->getWidestInductionType(), CM,
9606+
CM.CostKind);
9607+
VPlanTransforms::runPass(VPlanTransforms::convertToAbstractRecipes, *Plan,
9608+
CostCtx, Range);
9609+
}
9610+
9611+
for (ElementCount VF : Range)
9612+
Plan->addVF(VF);
9613+
Plan->setName("Initial VPlan");
9614+
96049615
// Interleave memory: for each Interleave Group we marked earlier as relevant
96059616
// for this VPlan, replace the Recipes widening its memory instructions with a
96069617
// single VPInterleaveRecipe at its insertion point.

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 252 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,8 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
517517
case VPRecipeBase::VPInstructionSC:
518518
case VPRecipeBase::VPReductionEVLSC:
519519
case VPRecipeBase::VPReductionSC:
520+
case VPRecipeBase::VPMulAccumulateReductionSC:
521+
case VPRecipeBase::VPExtendedReductionSC:
520522
case VPRecipeBase::VPReplicateSC:
521523
case VPRecipeBase::VPScalarIVStepsSC:
522524
case VPRecipeBase::VPVectorPointerSC:
@@ -601,13 +603,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
601603
DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {}
602604
};
603605

606+
/// Storage for the non-negative (nneg) flag of a recipe. It is taken by value
/// by the VPRecipeWithIRFlags constructor that sets OperationType::NonNegOp.
struct NonNegFlagsTy {
607+
/// One-bit flag; set when the operation is marked non-negative.
char NonNeg : 1;
608+
/// Initializes the flag from \p IsNonNeg.
NonNegFlagsTy(bool IsNonNeg) : NonNeg(IsNonNeg) {}
609+
};
610+
604611
private:
605612
struct ExactFlagsTy {
606613
char IsExact : 1;
607614
};
608-
struct NonNegFlagsTy {
609-
char NonNeg : 1;
610-
};
611615
struct FastMathFlagsTy {
612616
char AllowReassoc : 1;
613617
char NoNaNs : 1;
@@ -697,6 +701,12 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
697701
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp),
698702
DisjointFlags(DisjointFlags) {}
699703

704+
/// Construct a recipe with VPDef ID \p SC and operands \p Operands that
/// carries a non-negative flag: the operation type is set to NonNegOp and
/// \p NonNegFlags is stored as the active flag set. \p DL is forwarded to
/// VPSingleDefRecipe as the debug location.
template <typename IterT>
705+
VPRecipeWithIRFlags(const unsigned char SC, IterT Operands,
706+
NonNegFlagsTy NonNegFlags, DebugLoc DL = {})
707+
: VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::NonNegOp),
708+
NonNegFlags(NonNegFlags) {}
709+
700710
protected:
701711
VPRecipeWithIRFlags(const unsigned char SC, ArrayRef<VPValue *> Operands,
702712
GEPNoWrapFlags GEPFlags, DebugLoc DL = {})
@@ -715,7 +725,9 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
715725
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
716726
R->getVPDefID() == VPRecipeBase::VPReplicateSC ||
717727
R->getVPDefID() == VPRecipeBase::VPVectorEndPointerSC ||
718-
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC;
728+
R->getVPDefID() == VPRecipeBase::VPVectorPointerSC ||
729+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
730+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
719731
}
720732

721733
static inline bool classof(const VPUser *U) {
@@ -812,6 +824,15 @@ class VPRecipeWithIRFlags : public VPSingleDefRecipe {
812824

813825
FastMathFlags getFastMathFlags() const;
814826

827+
/// Returns true if the recipe has a non-negative flag, i.e. its operation
/// type is NonNegOp. Only in that case may isNonNeg() be queried.
828+
bool hasNonNegFlag() const { return OpType == OperationType::NonNegOp; }
829+
830+
/// Returns the value of the non-negative flag. Asserts that the recipe's
/// operation type is NonNegOp; check hasNonNegFlag() first when that is not
/// known.
bool isNonNeg() const {
831+
assert(OpType == OperationType::NonNegOp &&
832+
"recipe doesn't have a NNEG flag");
833+
return NonNegFlags.NonNeg;
834+
}
835+
815836
bool hasNoUnsignedWrap() const {
816837
assert(OpType == OperationType::OverflowingBinOp &&
817838
"recipe doesn't have a NUW flag");
@@ -1289,10 +1310,21 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
12891310
: VPRecipeWithIRFlags(VPDefOpcode, Operands, I), VPIRMetadata(I),
12901311
Opcode(I.getOpcode()) {}
12911312

1313+
/// Construct a widen recipe from a plain opcode, without going through an
/// Instruction. \p VPDefOpcode is the VPDef ID passed to the base class,
/// \p Opcode is stored as the operation to widen, and \p NUW / \p NSW are
/// packed into WrapFlagsTy and handed to VPRecipeWithIRFlags together with
/// \p DL. This overload is declared before the `public:` label that follows
/// it.
template <typename IterT>
1314+
VPWidenRecipe(unsigned VPDefOpcode, unsigned Opcode,
1315+
iterator_range<IterT> Operands, bool NUW, bool NSW, DebugLoc DL)
1316+
: VPRecipeWithIRFlags(VPDefOpcode, Operands, WrapFlagsTy(NUW, NSW), DL),
1317+
Opcode(Opcode) {}
1318+
12921319
public:
12931320
VPWidenRecipe(Instruction &I, ArrayRef<VPValue *> Operands)
12941321
: VPWidenRecipe(VPDef::VPWidenSC, I, Operands) {}
12951322

1323+
/// Construct a VPWidenSC recipe for \p Opcode over \p Operands with the given
/// no-unsigned-wrap (\p NUW) and no-signed-wrap (\p NSW) flags and debug
/// location \p DL. Delegates to the overload that takes an explicit VPDef ID.
template <typename IterT>
1324+
VPWidenRecipe(unsigned Opcode, iterator_range<IterT> Operands, bool NUW,
1325+
bool NSW, DebugLoc DL)
1326+
: VPWidenRecipe(VPDef::VPWidenSC, Opcode, Operands, NUW, NSW, DL) {}
1327+
12961328
~VPWidenRecipe() override = default;
12971329

12981330
VPWidenRecipe *clone() override {
@@ -1337,8 +1369,15 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags, public VPIRMetadata {
13371369
"opcode of underlying cast doesn't match");
13381370
}
13391371

1340-
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy)
1341-
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op), VPIRMetadata(),
1372+
/// Construct a widened cast of \p Op to \p ResultTy using cast opcode
/// \p Opcode, with no IR flags and default-constructed metadata. \p DL is the
/// debug location forwarded to the base recipe.
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1373+
DebugLoc DL = {})
1374+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, DL), VPIRMetadata(),
1375+
Opcode(Opcode), ResultTy(ResultTy) {}
1376+
1377+
/// Construct a widened cast of \p Op to \p ResultTy using cast opcode
/// \p Opcode that carries a non-negative flag set to \p IsNonNeg. \p DL is the
/// debug location forwarded to the base recipe. The VPIRMetadata base is not
/// named in the initializer list, so it is default-constructed.
VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy,
1378+
bool IsNonNeg, DebugLoc DL = {})
1379+
: VPRecipeWithIRFlags(VPDef::VPWidenCastSC, Op, NonNegFlagsTy(IsNonNeg),
1380+
DL),
13421381
Opcode(Opcode), ResultTy(ResultTy) {}
13431382

13441383
~VPWidenCastRecipe() override = default;
@@ -2381,6 +2420,28 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23812420
setUnderlyingValue(I);
23822421
}
23832422

2423+
/// For VPExtendedReductionRecipe.
2424+
/// Note that the debug location is from the extend.
/// Builds a reduction with VPDef ID \p SC and recurrence kind \p RdxKind over
/// \p Operands without any wrap flags. If \p CondOp is non-null it is appended
/// as an extra operand and the recipe is marked conditional.
2425+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2426+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2427+
bool IsOrdered, DebugLoc DL)
2428+
: VPRecipeWithIRFlags(SC, Operands, DL), RdxKind(RdxKind),
2429+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2430+
// The condition, when present, is always the last operand.
if (CondOp)
2431+
addOperand(CondOp);
2432+
}
2433+
2434+
/// For VPMulAccumulateReductionRecipe.
2435+
/// Note that the NUW/NSW flags and the debug location are from the Mul.
/// Builds a reduction with VPDef ID \p SC and recurrence kind \p RdxKind over
/// \p Operands, storing \p WrapFlags as the recipe's IR flags. If \p CondOp is
/// non-null it is appended as an extra operand and the recipe is marked
/// conditional.
2436+
VPReductionRecipe(const unsigned char SC, const RecurKind RdxKind,
2437+
ArrayRef<VPValue *> Operands, VPValue *CondOp,
2438+
bool IsOrdered, WrapFlagsTy WrapFlags, DebugLoc DL)
2439+
: VPRecipeWithIRFlags(SC, Operands, WrapFlags, DL), RdxKind(RdxKind),
2440+
IsOrdered(IsOrdered), IsConditional(CondOp) {
2441+
// The condition, when present, is always the last operand.
if (CondOp)
2442+
addOperand(CondOp);
2443+
}
2444+
23842445
public:
23852446
VPReductionRecipe(RecurKind RdxKind, FastMathFlags FMFs, Instruction *I,
23862447
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
@@ -2389,6 +2450,13 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23892450
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
23902451
IsOrdered, DL) {}
23912452

2453+
/// Construct a VPReductionSC recipe with operands {\p ChainOp, \p VecOp} and
/// optional condition \p CondOp, for callers that have no Instruction to
/// attach. nullptr is passed in the slot following \p FMFs of the delegated
/// constructor.
/// NOTE(review): that slot is presumably the underlying instruction, as in the
/// sibling overload taking `Instruction *I` -- confirm against the delegated
/// constructor.
VPReductionRecipe(const RecurKind RdxKind, FastMathFlags FMFs,
2454+
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
2455+
bool IsOrdered, DebugLoc DL = {})
2456+
: VPReductionRecipe(VPDef::VPReductionSC, RdxKind, FMFs, nullptr,
2457+
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
2458+
IsOrdered, DL) {}
2459+
23922460
~VPReductionRecipe() override = default;
23932461

23942462
VPReductionRecipe *clone() override {
@@ -2399,7 +2467,9 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23992467

24002468
static inline bool classof(const VPRecipeBase *R) {
24012469
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2402-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2470+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2471+
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
2472+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
24032473
}
24042474

24052475
static inline bool classof(const VPUser *U) {
@@ -2538,6 +2608,181 @@ class VPReductionEVLRecipe : public VPReductionRecipe {
25382608
}
25392609
};
25402610

2611+
/// A recipe to represent inloop extended reduction operations, performing a
2612+
/// reduction on a extended vector operand into a scalar value, and adding the
2613+
/// result to a chain. This recipe is abstract and needs to be lowered to
2614+
/// concrete recipes before codegen. The operands are {ChainOp, VecOp,
2615+
/// [Condition]}.
2616+
class VPExtendedReductionRecipe : public VPReductionRecipe {
2617+
/// Opcode of the extend recipe will be lowered to.
2618+
Instruction::CastOps ExtOp;
2619+
2620+
Type *ResultTy;
2621+
2622+
/// For cloning VPExtendedReductionRecipe.
2623+
VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
2624+
: VPReductionRecipe(
2625+
VPDef::VPExtendedReductionSC, ExtRed->getRecurrenceKind(),
2626+
{ExtRed->getChainOp(), ExtRed->getVecOp()}, ExtRed->getCondOp(),
2627+
ExtRed->isOrdered(), ExtRed->getDebugLoc()),
2628+
ExtOp(ExtRed->getExtOpcode()), ResultTy(ExtRed->getResultType()) {
2629+
transferFlags(*ExtRed);
2630+
}
2631+
2632+
public:
2633+
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2634+
: VPReductionRecipe(VPDef::VPExtendedReductionSC, R->getRecurrenceKind(),
2635+
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
2636+
R->isOrdered(), Ext->getDebugLoc()),
2637+
ExtOp(Ext->getOpcode()), ResultTy(Ext->getResultType()) {
2638+
// Not all WidenCastRecipes contain nneg flag. Need to transfer flags from
2639+
// the original recipe to prevent setting wrong flags.
2640+
transferFlags(*Ext);
2641+
}
2642+
2643+
~VPExtendedReductionRecipe() override = default;
2644+
2645+
VPExtendedReductionRecipe *clone() override {
2646+
auto *Copy = new VPExtendedReductionRecipe(this);
2647+
Copy->transferFlags(*this);
2648+
return Copy;
2649+
}
2650+
2651+
VP_CLASSOF_IMPL(VPDef::VPExtendedReductionSC);
2652+
2653+
void execute(VPTransformState &State) override {
2654+
llvm_unreachable("VPExtendedReductionRecipe should be transform to "
2655+
"VPExtendedRecipe + VPReductionRecipe before execution.");
2656+
};
2657+
2658+
/// Return the cost of VPExtendedReductionRecipe.
2659+
InstructionCost computeCost(ElementCount VF,
2660+
VPCostContext &Ctx) const override;
2661+
2662+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2663+
/// Print the recipe.
2664+
void print(raw_ostream &O, const Twine &Indent,
2665+
VPSlotTracker &SlotTracker) const override;
2666+
#endif
2667+
2668+
/// The scalar type after extending.
2669+
Type *getResultType() const { return ResultTy; }
2670+
2671+
/// Is the extend ZExt?
2672+
bool isZExt() const { return getExtOpcode() == Instruction::ZExt; }
2673+
2674+
/// The opcode of extend recipe.
2675+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2676+
};
2677+
2678+
/// A recipe to represent inloop MulAccumulateReduction operations, performing a
2679+
/// reduction.add on the result of vector operands (might be extended)
2680+
/// multiplication into a scalar value, and adding the result to a chain. This
2681+
/// recipe is abstract and needs to be lowered to concrete recipes before
2682+
/// codegen. The operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
2683+
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2684+
/// Opcode of the extend recipe.
2685+
Instruction::CastOps ExtOp;
2686+
2687+
/// Non-neg flag of the extend recipe.
2688+
bool IsNonNeg = false;
2689+
2690+
Type *ResultTy;
2691+
2692+
/// For cloning VPMulAccumulateReductionRecipe.
2693+
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
2694+
: VPReductionRecipe(
2695+
VPDef::VPMulAccumulateReductionSC, MulAcc->getRecurrenceKind(),
2696+
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
2697+
MulAcc->getCondOp(), MulAcc->isOrdered(),
2698+
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
2699+
MulAcc->getDebugLoc()),
2700+
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2701+
ResultTy(MulAcc->getResultType()) {}
2702+
2703+
public:
2704+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
2705+
VPWidenCastRecipe *Ext0,
2706+
VPWidenCastRecipe *Ext1, Type *ResultTy)
2707+
: VPReductionRecipe(
2708+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2709+
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
2710+
R->getCondOp(), R->isOrdered(),
2711+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2712+
R->getDebugLoc()),
2713+
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
2714+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2715+
Instruction::Add &&
2716+
"The reduction instruction in MulAccumulateteReductionRecipe must "
2717+
"be Add");
2718+
// Only set the non-negative flag if the original recipe contains.
2719+
if (Ext0->hasNonNegFlag())
2720+
IsNonNeg = Ext0->isNonNeg();
2721+
}
2722+
2723+
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
2724+
: VPReductionRecipe(
2725+
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
2726+
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
2727+
R->getCondOp(), R->isOrdered(),
2728+
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2729+
R->getDebugLoc()),
2730+
ExtOp(Instruction::CastOps::CastOpsEnd) {
2731+
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
2732+
Instruction::Add &&
2733+
"The reduction instruction in MulAccumulateReductionRecipe must be "
2734+
"Add");
2735+
}
2736+
2737+
~VPMulAccumulateReductionRecipe() override = default;
2738+
2739+
VPMulAccumulateReductionRecipe *clone() override {
2740+
auto *Copy = new VPMulAccumulateReductionRecipe(this);
2741+
Copy->transferFlags(*this);
2742+
return Copy;
2743+
}
2744+
2745+
VP_CLASSOF_IMPL(VPDef::VPMulAccumulateReductionSC);
2746+
2747+
void execute(VPTransformState &State) override {
2748+
llvm_unreachable("VPMulAccumulateReductionRecipe should transform to "
2749+
"VPWidenCastRecipe + "
2750+
"VPWidenRecipe + VPReductionRecipe before execution");
2751+
}
2752+
2753+
/// Return the cost of VPMulAccumulateReductionRecipe.
2754+
InstructionCost computeCost(ElementCount VF,
2755+
VPCostContext &Ctx) const override;
2756+
2757+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
2758+
/// Print the recipe.
2759+
void print(raw_ostream &O, const Twine &Indent,
2760+
VPSlotTracker &SlotTracker) const override;
2761+
#endif
2762+
2763+
Type *getResultType() const {
2764+
assert(isExtended() && "Only support getResultType when this recipe "
2765+
"contains implicit extend.");
2766+
return ResultTy;
2767+
}
2768+
2769+
/// The VPValue of the vector value to be extended and reduced.
2770+
VPValue *getVecOp0() const { return getOperand(1); }
2771+
VPValue *getVecOp1() const { return getOperand(2); }
2772+
2773+
/// Return if this MulAcc recipe contains extended operands.
2774+
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2775+
2776+
/// Return the opcode of the extends for the operands.
2777+
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2778+
2779+
/// Return if the operands are zero extended.
2780+
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2781+
2782+
/// Return the non negative flag of the ext recipe.
2783+
bool isNonNeg() const { return IsNonNeg; }
2784+
};
2785+
25412786
/// VPReplicateRecipe replicates a given instruction producing multiple scalar
25422787
/// copies of the original scalar type, one per lane, instead of producing a
25432788
/// single copy of widened type for all lanes. If the instruction is known to be

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
273273
// TODO: Use info from interleave group.
274274
return V->getUnderlyingValue()->getType();
275275
})
276+
.Case<VPExtendedReductionRecipe, VPMulAccumulateReductionRecipe>(
277+
[](const auto *R) { return R->getResultType(); })
276278
.Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
277279
return R->getSCEV()->getType();
278280
})

0 commit comments

Comments
 (0)