Skip to content

Commit fedea8b

Browse files
committed
Bundle partial reductions inside VPMulAccumulateReductionRecipe
1 parent b8c4eea commit fedea8b

File tree

11 files changed

+763
-864
lines changed

11 files changed

+763
-864
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ class TargetTransformInfo {
223223
/// Get the kind of extension that an instruction represents.
224224
LLVM_ABI static PartialReductionExtendKind
225225
getPartialReductionExtendKind(Instruction *I);
226+
static PartialReductionExtendKind
227+
getPartialReductionExtendKind(Instruction::CastOps ExtOpcode);
226228

227229
/// Construct a TTI object using a type implementing the \c Concept
228230
/// API below.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -995,13 +995,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(
995995

996996
TargetTransformInfo::PartialReductionExtendKind
997997
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
998-
if (isa<SExtInst>(I))
999-
return PR_SignExtend;
1000-
if (isa<ZExtInst>(I))
1001-
return PR_ZeroExtend;
998+
if (auto *Cast = dyn_cast<CastInst>(I))
999+
return getPartialReductionExtendKind(Cast->getOpcode());
10021000
return PR_None;
10031001
}
10041002

1003+
TargetTransformInfo::PartialReductionExtendKind
1004+
TargetTransformInfo::getPartialReductionExtendKind(
1005+
Instruction::CastOps ExtOpcode) {
1006+
switch (ExtOpcode) {
1007+
case Instruction::CastOps::ZExt:
1008+
return PR_ZeroExtend;
1009+
case Instruction::CastOps::SExt:
1010+
return PR_SignExtend;
1011+
default:
1012+
llvm_unreachable("Unexpected cast opcode");
1013+
}
1014+
}
1015+
10051016
TTI::CastContextHint
10061017
TargetTransformInfo::getCastContextHint(const Instruction *I) {
10071018
if (!I)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8639,9 +8639,6 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
86398639
"Expected an ADD or SUB operation for predicated partial "
86408640
"reductions (because the neutral element in the mask is zero)!");
86418641
Cond = getBlockInMask(Builder.getInsertBlock());
8642-
VPValue *Zero =
8643-
Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
8644-
BinOp = Builder.createSelect(Cond, BinOp, Zero, Reduction->getDebugLoc());
86458642
}
86468643
return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
86478644
ScaleFactor, Reduction);

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,7 +2470,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
24702470
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
24712471
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
24722472
R->getVPDefID() == VPRecipeBase::VPExtendedReductionSC ||
2473-
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC;
2473+
R->getVPDefID() == VPRecipeBase::VPMulAccumulateReductionSC ||
2474+
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
24742475
}
24752476

24762477
static inline bool classof(const VPUser *U) {
@@ -2559,6 +2560,9 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
25592560
/// Get the factor that the VF of this recipe's output should be scaled by.
25602561
unsigned getVFScaleFactor() const { return VFScaleFactor; }
25612562

2563+
/// Get the binary op this reduction is applied to.
2564+
VPValue *getBinOp() const { return getOperand(1); }
2565+
25622566
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
25632567
/// Print the recipe.
25642568
void print(raw_ostream &O, const Twine &Indent,
@@ -2694,6 +2698,10 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
26942698
/// The scalar type after extending.
26952699
Type *ResultTy = nullptr;
26962700

2701+
/// The scaling factor, relative to the VF, that this recipe's output is
2702+
/// divided by
2703+
unsigned VFScaleFactor = 0;
2704+
26972705
/// For cloning VPMulAccumulateReductionRecipe.
26982706
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
26992707
: VPReductionRecipe(
@@ -2703,22 +2711,25 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27032711
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
27042712
MulAcc->getDebugLoc()),
27052713
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2706-
ResultTy(MulAcc->getResultType()) {
2714+
ResultTy(MulAcc->getResultType()),
2715+
VFScaleFactor(MulAcc->getVFScaleFactor()) {
27072716
transferFlags(*MulAcc);
27082717
setUnderlyingValue(MulAcc->getUnderlyingValue());
27092718
}
27102719

27112720
public:
27122721
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
27132722
VPWidenCastRecipe *Ext0,
2714-
VPWidenCastRecipe *Ext1, Type *ResultTy)
2723+
VPWidenCastRecipe *Ext1, Type *ResultTy,
2724+
unsigned ScaleFactor = 1)
27152725
: VPReductionRecipe(
27162726
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceKind(),
27172727
{R->getChainOp(), Ext0->getOperand(0), Ext1->getOperand(0)},
27182728
R->getCondOp(), R->isOrdered(),
27192729
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
27202730
R->getDebugLoc()),
2721-
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy) {
2731+
ExtOp(Ext0->getOpcode()), ResultTy(ResultTy),
2732+
VFScaleFactor(ScaleFactor) {
27222733
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
27232734
Instruction::Add &&
27242735
"The reduction instruction in MulAccumulateteReductionRecipe must "
@@ -2791,6 +2802,10 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
27912802

27922803
/// Return true if the operand extends have the non-negative flag.
27932804
bool isNonNeg() const { return IsNonNeg; }
2805+
2806+
/// Return the scaling factor that the VF is divided by to form the recipe's
2807+
/// output
2808+
unsigned getVFScaleFactor() const { return VFScaleFactor; }
27942809
};
27952810

27962811
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
159159
case VPWidenIntrinsicSC:
160160
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
161161
case VPBlendSC:
162+
case VPPartialReductionSC:
162163
case VPReductionEVLSC:
163164
case VPReductionSC:
164165
case VPExtendedReductionSC:
@@ -295,14 +296,9 @@ InstructionCost
295296
VPPartialReductionRecipe::computeCost(ElementCount VF,
296297
VPCostContext &Ctx) const {
297298
std::optional<unsigned> Opcode = std::nullopt;
298-
VPValue *BinOp = getOperand(1);
299+
VPValue *BinOp = getBinOp();
299300

300-
// If the partial reduction is predicated, a select will be operand 0 rather
301-
// than the binary op
302301
using namespace llvm::VPlanPatternMatch;
303-
if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
304-
BinOp = BinOp->getDefiningRecipe()->getOperand(1);
305-
306302
// If BinOp is a negation, use the side effect of match to assign the actual
307303
// binary operation to BinOp
308304
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
@@ -345,12 +341,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
345341
assert(getOpcode() == Instruction::Add &&
346342
"Unhandled partial reduction opcode");
347343

348-
Value *BinOpVal = State.get(getOperand(1));
349-
Value *PhiVal = State.get(getOperand(0));
344+
Value *BinOpVal = State.get(getBinOp());
345+
Value *PhiVal = State.get(getChainOp());
350346
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
351347

352348
Type *RetTy = PhiVal->getType();
353349

350+
/// Mask the bin op output.
351+
if (VPValue *Cond = getCondOp()) {
352+
Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
353+
BinOpVal = Builder.CreateSelect(State.get(Cond), BinOpVal, Zero);
354+
}
355+
354356
CallInst *V = Builder.CreateIntrinsic(
355357
RetTy, Intrinsic::experimental_vector_partial_reduce_add,
356358
{PhiVal, BinOpVal}, nullptr, "partial.reduce");
@@ -2570,6 +2572,14 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
25702572
InstructionCost
25712573
VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
25722574
VPCostContext &Ctx) const {
2575+
if (getVFScaleFactor() > 1) {
2576+
return Ctx.TTI.getPartialReductionCost(
2577+
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
2578+
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
2579+
TTI::getPartialReductionExtendKind(getExtOpcode()),
2580+
TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
2581+
}
2582+
25732583
Type *RedTy = Ctx.Types.inferScalarType(this);
25742584
auto *SrcVecTy =
25752585
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
@@ -2648,6 +2658,8 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
26482658
O << " = ";
26492659
getChainOp()->printAsOperand(O, SlotTracker);
26502660
O << " + ";
2661+
if (getVFScaleFactor() > 1)
2662+
O << "partial.";
26512663
O << "reduce."
26522664
<< Instruction::getOpcodeName(
26532665
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,9 +2581,15 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
25812581
MulAcc->hasNoSignedWrap(), MulAcc->getDebugLoc());
25822582
Mul->insertBefore(MulAcc);
25832583

2584-
auto *Red = new VPReductionRecipe(
2585-
MulAcc->getRecurrenceKind(), FastMathFlags(), MulAcc->getChainOp(), Mul,
2586-
MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getDebugLoc());
2584+
// Generate VPReductionRecipe.
2585+
VPReductionRecipe *Red = nullptr;
2586+
if (unsigned ScaleFactor = MulAcc->getVFScaleFactor(); ScaleFactor > 1)
2587+
Red = new VPPartialReductionRecipe(Instruction::Add, MulAcc->getChainOp(),
2588+
Mul, MulAcc->getCondOp(), ScaleFactor);
2589+
else
2590+
Red = new VPReductionRecipe(MulAcc->getRecurrenceKind(), FastMathFlags(),
2591+
MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
2592+
MulAcc->isOrdered(), MulAcc->getDebugLoc());
25872593
Red->insertBefore(MulAcc);
25882594

25892595
MulAcc->replaceAllUsesWith(Red);
@@ -2911,12 +2917,43 @@ static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
29112917
Red->replaceAllUsesWith(AbstractR);
29122918
}
29132919

2920+
/// This function tries to create an abstract recipe from a partial reduction to
2921+
/// hide its mul and extends from cost estimation.
2922+
static void
2923+
tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
2924+
if (PRed->getOpcode() != Instruction::Add)
2925+
return;
2926+
2927+
using namespace llvm::VPlanPatternMatch;
2928+
auto *BinOp = PRed->getBinOp();
2929+
if (!match(BinOp,
2930+
m_Mul(m_ZExtOrSExt(m_VPValue()), m_ZExtOrSExt(m_VPValue()))))
2931+
return;
2932+
2933+
auto *BinOpR = cast<VPWidenRecipe>(BinOp->getDefiningRecipe());
2934+
VPWidenCastRecipe *Ext0R = dyn_cast<VPWidenCastRecipe>(BinOpR->getOperand(0));
2935+
VPWidenCastRecipe *Ext1R = dyn_cast<VPWidenCastRecipe>(BinOpR->getOperand(1));
2936+
2937+
// TODO: Make work with extends of different signedness
2938+
if (Ext0R->hasMoreThanOneUniqueUser() || Ext1R->hasMoreThanOneUniqueUser() ||
2939+
Ext0R->getOpcode() != Ext1R->getOpcode())
2940+
return;
2941+
2942+
auto *AbstractR = new VPMulAccumulateReductionRecipe(
2943+
PRed, BinOpR, Ext0R, Ext1R, Ext0R->getResultType(),
2944+
PRed->getVFScaleFactor());
2945+
AbstractR->insertBefore(PRed);
2946+
PRed->replaceAllUsesWith(AbstractR);
2947+
}
2948+
29142949
void VPlanTransforms::convertToAbstractRecipes(VPlan &Plan, VPCostContext &Ctx,
29152950
VFRange &Range) {
29162951
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
29172952
vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
29182953
for (VPRecipeBase &R : *VPBB) {
2919-
if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
2954+
if (auto *PRed = dyn_cast<VPPartialReductionRecipe>(&R))
2955+
tryToCreateAbstractPartialReductionRecipe(PRed);
2956+
else if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
29202957
tryToCreateAbstractReductionRecipe(Red, Ctx, Range);
29212958
}
29222959
}

0 commit comments

Comments
 (0)