Skip to content

Commit d85eb07

Browse files
committed
Use simple partial reduction case for reduction and extended reduction cost computation
1 parent 59545b9 commit d85eb07

File tree

8 files changed

+430
-397
lines changed

8 files changed

+430
-397
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2661,12 +2661,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
26612661
/// and needs to be lowered to concrete recipes before codegen. The operands are
26622662
/// {ChainOp, VecOp1, VecOp2, [Condition]}.
26632663
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2664-
/// Opcode of the extend for VecOp1 and VecOp2.
2665-
Instruction::CastOps ExtOp;
2666-
2667-
/// Non-neg flag of the extend recipe.
2668-
bool IsNonNeg = false;
2669-
26702664
/// The scalar type after extending.
26712665
Type *ResultTy = nullptr;
26722666

@@ -2679,8 +2673,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
26792673
MulAcc->getVFScaleFactor(),
26802674
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
26812675
MulAcc->getDebugLoc()),
2682-
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2683-
ResultTy(MulAcc->getResultType()) {
2676+
ResultTy(MulAcc->getResultType()),
2677+
VecOpInfo{MulAcc->getVecOp0Info(), MulAcc->getVecOp1Info()} {
26842678
transferFlags(*MulAcc);
26852679
setUnderlyingValue(MulAcc->getUnderlyingValue());
26862680
}
@@ -2695,18 +2689,22 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
26952689
R->getCondOp(), R->isOrdered(), R->getVFScaleFactor(),
26962690
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
26972691
R->getDebugLoc()),
2698-
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
2692+
ResultTy(ResultTy),
2693+
VecOpInfo{
2694+
{Ext0->getOpcode(), Ext0->hasNonNegFlag() && Ext0->isNonNeg()},
2695+
{Ext1->getOpcode(), Ext1->hasNonNegFlag() && Ext1->isNonNeg()}} {
26992696
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27002697
Instruction::Add &&
27012698
"The reduction instruction in MulAccumulateteReductionRecipe must "
27022699
"be Add");
2703-
assert((ExtOp == Instruction::CastOps::ZExt ||
2704-
ExtOp == Instruction::CastOps::SExt) &&
2700+
unsigned ExtOp0 = getVecOp0Info().ExtOp;
2701+
unsigned ExtOp1 = getVecOp1Info().ExtOp;
2702+
assert((ExtOp0 == Instruction::CastOps::ZExt ||
2703+
ExtOp0 == Instruction::CastOps::SExt) &&
2704+
(ExtOp1 == Instruction::CastOps::ZExt ||
2705+
ExtOp1 == Instruction::CastOps::SExt) &&
27052706
"VPMulAccumulateReductionRecipe only supports zext and sext.");
27062707
setUnderlyingValue(R->getUnderlyingValue());
2707-
// Only set the non-negative flag if the original recipe has one.
2708-
if (Ext0->hasNonNegFlag())
2709-
IsNonNeg = Ext0->isNonNeg();
27102708
}
27112709

27122710
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2717,14 +2715,26 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27172715
R->getCondOp(), R->isOrdered(), R->getVFScaleFactor(),
27182716
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
27192717
R->getDebugLoc()),
2720-
ExtOp(Instruction::CastOps::CastOpsEnd), ResultTy(ResultTy) {
2718+
ResultTy(ResultTy) {
27212719
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27222720
Instruction::Add &&
27232721
"The reduction instruction in MulAccumulateReductionRecipe must be "
27242722
"Add");
27252723
setUnderlyingValue(R->getUnderlyingValue());
27262724
}
27272725

2726+
struct VecOperandInfo {
2727+
/// The operand's extend opcode.
2728+
Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
2729+
/// Non-neg portion of the operand's flags.
2730+
bool IsNonNeg = false;
2731+
2732+
bool isExtended() const {
2733+
return ExtOp != Instruction::CastOps::CastOpsEnd;
2734+
}
2735+
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2736+
};
2737+
27282738
~VPMulAccumulateReductionRecipe() override = default;
27292739

27302740
VPMulAccumulateReductionRecipe *clone() override {
@@ -2758,16 +2768,15 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27582768
VPValue *getVecOp1() const { return getOperand(2); }
27592769

27602770
/// Return true if this recipe contains extended operands.
2761-
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
2762-
2763-
/// Return the opcode of the extends for the operands.
2764-
Instruction::CastOps getExtOpcode() const { return ExtOp; }
2771+
bool isExtended() const {
2772+
return getVecOp0Info().isExtended() || getVecOp1Info().isExtended();
2773+
}
27652774

2766-
/// Return if the operands are zero-extended.
2767-
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
2775+
const VecOperandInfo &getVecOp0Info() const { return VecOpInfo[0]; }
2776+
const VecOperandInfo &getVecOp1Info() const { return VecOpInfo[1]; }
27682777

2769-
/// Return true if the operand extends have the non-negative flag.
2770-
bool isNonNeg() const { return IsNonNeg; }
2778+
protected:
2779+
VecOperandInfo VecOpInfo[2];
27712780
};
27722781

27732782
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,29 +2495,10 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
24952495
std::optional<FastMathFlags> OptionalFMF =
24962496
ElementTy->isFloatingPointTy() ? std::make_optional(FMFs) : std::nullopt;
24972497

2498-
if (isPartialReduction()) {
2499-
using namespace llvm::VPlanPatternMatch;
2500-
VPValue *Mul = getVecOp();
2501-
// Some chained partial reductions used for complex numbers will have a
2502-
// negation between the mul and reduction. This extracts the mul from that
2503-
// pattern to use it for further checking.
2504-
match(Mul, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul)));
2505-
if (match(Mul,
2506-
m_Mul(m_ZExtOrSExt(m_VPValue()), m_ZExtOrSExt(m_VPValue())))) {
2507-
auto *MulR = cast<VPWidenRecipe>(Mul);
2508-
auto *Ext0R = cast<VPWidenCastRecipe>(MulR->getOperand(0));
2509-
auto *Ext1R = cast<VPWidenCastRecipe>(MulR->getOperand(1));
2510-
return Ctx.TTI.getPartialReductionCost(
2511-
Opcode, Ctx.Types.inferScalarType(Ext0R->getOperand(0)),
2512-
Ctx.Types.inferScalarType(Ext1R->getOperand(0)),
2513-
Ctx.Types.inferScalarType(getChainOp()), VF,
2514-
TargetTransformInfo::getPartialReductionExtendKind(
2515-
Ext0R->getOpcode()),
2516-
TargetTransformInfo::getPartialReductionExtendKind(
2517-
Ext1R->getOpcode()),
2518-
Instruction::Mul);
2519-
}
2520-
}
2498+
if (isPartialReduction())
2499+
return Ctx.TTI.getPartialReductionCost(
2500+
Opcode, ElementTy, ElementTy, ElementTy, VF,
2501+
TargetTransformInfo::PR_None, TargetTransformInfo::PR_None);
25212502

25222503
// TODO: Support any-of reductions.
25232504
assert(
@@ -2547,27 +2528,36 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
25472528
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
25482529
assert(RedTy->isIntegerTy() &&
25492530
"ExtendedReduction only support integer type currently.");
2531+
if (isPartialReduction())
2532+
return Ctx.TTI.getPartialReductionCost(Opcode, RedTy, SrcVecTy, SrcVecTy,
2533+
VF, TargetTransformInfo::PR_None,
2534+
TargetTransformInfo::PR_None);
25502535
return Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), RedTy, SrcVecTy,
25512536
std::nullopt, Ctx.CostKind);
25522537
}
25532538

25542539
InstructionCost
25552540
VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
25562541
VPCostContext &Ctx) const {
2542+
VecOperandInfo Ext0Info = getVecOp0Info();
2543+
VecOperandInfo Ext1Info = getVecOp1Info();
25572544
if (isPartialReduction())
25582545
return Ctx.TTI.getPartialReductionCost(
25592546
RecurrenceDescriptor::getOpcode(getRecurrenceKind()),
25602547
Ctx.Types.inferScalarType(getVecOp0()),
25612548
Ctx.Types.inferScalarType(getVecOp1()),
25622549
Ctx.Types.inferScalarType(getChainOp()), VF,
2563-
TargetTransformInfo::getPartialReductionExtendKind(getExtOpcode()),
2564-
TargetTransformInfo::getPartialReductionExtendKind(getExtOpcode()),
2550+
TargetTransformInfo::getPartialReductionExtendKind(Ext0Info.ExtOp),
2551+
TargetTransformInfo::getPartialReductionExtendKind(Ext1Info.ExtOp),
25652552
Instruction::Mul);
2553+
// Only partial reductions support mixed extends (different extend opcodes
// for the two operands); otherwise report an invalid cost.
2554+
if (Ext0Info.ExtOp != Ext1Info.ExtOp)
2555+
return InstructionCost::getInvalid(Ctx.CostKind);
25662556

25672557
Type *RedTy = Ctx.Types.inferScalarType(this);
25682558
auto *SrcVecTy =
25692559
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
2570-
return Ctx.TTI.getMulAccReductionCost(isZExt(), RedTy, SrcVecTy,
2560+
return Ctx.TTI.getMulAccReductionCost(Ext0Info.isZExt(), RedTy, SrcVecTy,
25712561
Ctx.CostKind);
25722562
}
25732563

@@ -2653,18 +2643,20 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
26532643
<< " (";
26542644
O << "mul";
26552645
printFlags(O);
2646+
VecOperandInfo Ext0Info = getVecOp0Info();
2647+
VecOperandInfo Ext1Info = getVecOp1Info();
26562648
if (isExtended())
26572649
O << "(";
26582650
getVecOp0()->printAsOperand(O, SlotTracker);
2659-
if (isExtended())
2660-
O << " " << Instruction::getOpcodeName(ExtOp) << " to " << *getResultType()
2661-
<< "), (";
2651+
if (Ext0Info.isExtended())
2652+
O << " " << Instruction::getOpcodeName(Ext0Info.ExtOp) << " to "
2653+
<< *getResultType() << "), (";
26622654
else
26632655
O << ", ";
26642656
getVecOp1()->printAsOperand(O, SlotTracker);
2665-
if (isExtended())
2666-
O << " " << Instruction::getOpcodeName(ExtOp) << " to " << *getResultType()
2667-
<< ")";
2657+
if (Ext1Info.isExtended())
2658+
O << " " << Instruction::getOpcodeName(Ext1Info.ExtOp) << " to "
2659+
<< *getResultType() << ")";
26682660
if (isConditional()) {
26692661
O << ", ";
26702662
getCondOp()->printAsOperand(O, SlotTracker);

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2563,28 +2563,31 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
25632563
// reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
25642564
VPValue *Op0, *Op1;
25652565
if (MulAcc->isExtended()) {
2566+
VPMulAccumulateReductionRecipe::VecOperandInfo Ext0Info =
2567+
MulAcc->getVecOp0Info();
2568+
VPMulAccumulateReductionRecipe::VecOperandInfo Ext1Info =
2569+
MulAcc->getVecOp1Info();
25662570
Type *RedTy = MulAcc->getResultType();
2567-
if (MulAcc->isZExt())
2568-
Op0 = new VPWidenCastRecipe(
2569-
MulAcc->getExtOpcode(), MulAcc->getVecOp0(), RedTy,
2570-
VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg()), MulAcc->getDebugLoc());
2571+
if (Ext0Info.isZExt())
2572+
Op0 = new VPWidenCastRecipe(Ext0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
2573+
VPIRFlags::NonNegFlagsTy(Ext0Info.IsNonNeg),
2574+
MulAcc->getDebugLoc());
25712575
else
2572-
Op0 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp0(),
2573-
RedTy, {}, MulAcc->getDebugLoc());
2576+
Op0 = new VPWidenCastRecipe(Ext0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
2577+
{}, MulAcc->getDebugLoc());
25742578
Op0->getDefiningRecipe()->insertBefore(MulAcc);
25752579
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
25762580
// VPWidenCastRecipe.
25772581
if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
25782582
Op1 = Op0;
25792583
} else {
2580-
if (MulAcc->isZExt())
2581-
Op1 = new VPWidenCastRecipe(
2582-
MulAcc->getExtOpcode(), MulAcc->getVecOp1(), RedTy,
2583-
VPIRFlags::NonNegFlagsTy(MulAcc->isNonNeg()),
2584-
MulAcc->getDebugLoc());
2584+
if (Ext1Info.isZExt())
2585+
Op1 = new VPWidenCastRecipe(Ext1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
2586+
VPIRFlags::NonNegFlagsTy(Ext1Info.IsNonNeg),
2587+
MulAcc->getDebugLoc());
25852588
else
2586-
Op1 = new VPWidenCastRecipe(MulAcc->getExtOpcode(), MulAcc->getVecOp1(),
2587-
RedTy, {}, MulAcc->getDebugLoc());
2589+
Op1 = new VPWidenCastRecipe(Ext1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
2590+
{}, MulAcc->getDebugLoc());
25882591
Op1->getDefiningRecipe()->insertBefore(MulAcc);
25892592
}
25902593
} else {
@@ -2835,16 +2838,36 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
28352838

28362839
// Clamp the range if using multiply-accumulate-reduction is profitable.
28372840
auto IsMulAccValidAndClampRange =
2838-
[&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
2841+
[&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
28392842
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
28402843
return LoopVectorizationPlanner::getDecisionAndClampRange(
28412844
[&](ElementCount VF) {
28422845
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2843-
Type *SrcTy =
2846+
Type *SrcTy0 =
28442847
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
2845-
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
2846-
InstructionCost MulAccCost =
2847-
Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
2848+
Type *SrcTy1 =
2849+
Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : RedTy;
2850+
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy0, VF));
2851+
InstructionCost MulAccCost;
2852+
if (Red->isPartialReduction()) {
2853+
TargetTransformInfo::PartialReductionExtendKind Ext0Kind =
2854+
Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
2855+
Ext0->getOpcode())
2856+
: TargetTransformInfo::PR_None;
2857+
TargetTransformInfo::PartialReductionExtendKind Ext1Kind =
2858+
Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
2859+
Ext1->getOpcode())
2860+
: TargetTransformInfo::PR_None;
2861+
MulAccCost = Ctx.TTI.getPartialReductionCost(
2862+
Opcode, SrcTy0, SrcTy1, RedTy, VF, Ext0Kind, Ext1Kind,
2863+
Mul->getOpcode());
2864+
} else {
2865+
// Currently, only partial reductions support mixed extension types
// (the two operands using different extend opcodes).
2866+
if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
2867+
return false;
2868+
MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, RedTy, SrcVecTy,
2869+
CostKind);
2870+
}
28482871
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
28492872
InstructionCost RedCost = Red->computeCost(VF, Ctx);
28502873
InstructionCost ExtCost = 0;
@@ -2863,6 +2886,12 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
28632886

28642887
VPValue *VecOp = Red->getVecOp();
28652888
VPValue *A, *B;
2889+
// Some chained partial reductions used for complex numbers will have a
2890+
// negation between the mul and reduction. This extracts the mul from that
2891+
// pattern to use it for further checking.
2892+
if (Red->isPartialReduction())
2893+
match(VecOp,
2894+
m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(VecOp)));
28662895
// Try to match reduce.add(mul(...)).
28672896
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
28682897
auto *RecipeA =
@@ -2872,8 +2901,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
28722901
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
28732902

28742903
// Match reduce.add(mul(ext, ext)).
2904+
// Mixed extensions (differing extend opcodes) are only valid for partial
// reductions.
28752905
if (RecipeA && RecipeB &&
2876-
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
2906+
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B ||
2907+
Red->isPartialReduction()) &&
28772908
match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
28782909
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
28792910
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==

0 commit comments

Comments
 (0)