@@ -7606,6 +7606,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
76067606 }
76077607 continue ;
76087608 }
7609+ // The VPlan-based cost model is more accurate for partial reduction and
7610+ // comparing against the legacy cost isn't desirable.
7611+ if (isa<VPPartialReductionRecipe>(&R))
7612+ return true ;
76097613 if (Instruction *UI = GetInstructionForCost (&R))
76107614 SeenInstrs.insert (UI);
76117615 }
@@ -8828,6 +8832,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
88288832 return Recipe;
88298833}
88308834
8835+ // / Find all possible partial reductions in the loop and track all of those that
8836+ // / are valid so recipes can be formed later.
8837+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8838+ // Find all possible partial reductions.
8839+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8840+ PartialReductionChains;
8841+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8842+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8843+ getScaledReduction (Phi, RdxDesc, Range))
8844+ PartialReductionChains.push_back (*Pair);
8845+
8846+ // A partial reduction is invalid if any of its extends are used by
8847+ // something that isn't another partial reduction. This is because the
8848+ // extends are intended to be lowered along with the reduction itself.
8849+
8850+ // Build up a set of partial reduction bin ops for efficient use checking.
8851+ SmallSet<User *, 4 > PartialReductionBinOps;
8852+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8853+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8854+
8855+ auto ExtendIsOnlyUsedByPartialReductions =
8856+ [&PartialReductionBinOps](Instruction *Extend) {
8857+ return all_of (Extend->users (), [&](const User *U) {
8858+ return PartialReductionBinOps.contains (U);
8859+ });
8860+ };
8861+
8862+ // Check if each use of a chain's two extends is a partial reduction
8863+ // and only add those that don't have non-partial reduction users.
8864+ for (auto Pair : PartialReductionChains) {
8865+ PartialReductionChain Chain = Pair.first ;
8866+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8867+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8868+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8869+ }
8870+ }
8871+
8872+ std::optional<std::pair<PartialReductionChain, unsigned >>
8873+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8874+ const RecurrenceDescriptor &Rdx,
8875+ VFRange &Range) {
8876+ // TODO: Allow scaling reductions when predicating. The select at
8877+ // the end of the loop chooses between the phi value and most recent
8878+ // reduction result, both of which have different VFs to the active lane
8879+ // mask when scaling.
8880+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8881+ return std::nullopt ;
8882+
8883+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8884+ if (!Update)
8885+ return std::nullopt ;
8886+
8887+ Value *Op = Update->getOperand (0 );
8888+ if (Op == PHI)
8889+ Op = Update->getOperand (1 );
8890+
8891+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8892+ if (!BinOp || !BinOp->hasOneUse ())
8893+ return std::nullopt ;
8894+
8895+ using namespace llvm ::PatternMatch;
8896+ Value *A, *B;
8897+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8898+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8899+ return std::nullopt ;
8900+
8901+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8902+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8903+
8904+ // Check that the extends extend from the same type.
8905+ if (A->getType () != B->getType ())
8906+ return std::nullopt ;
8907+
8908+ TTI::PartialReductionExtendKind OpAExtend =
8909+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8910+ TTI::PartialReductionExtendKind OpBExtend =
8911+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8912+
8913+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8914+
8915+ unsigned TargetScaleFactor =
8916+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8917+ A->getType ()->getPrimitiveSizeInBits ());
8918+
8919+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8920+ [&](ElementCount VF) {
8921+ InstructionCost Cost = TTI->getPartialReductionCost (
8922+ Update->getOpcode (), A->getType (), PHI->getType (), VF,
8923+ OpAExtend, OpBExtend, std::make_optional (BinOp->getOpcode ()));
8924+ return Cost.isValid ();
8925+ },
8926+ Range))
8927+ return std::make_pair (Chain, TargetScaleFactor);
8928+
8929+ return std::nullopt ;
8930+ }
8931+
88318932VPRecipeBase *
88328933VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
88338934 ArrayRef<VPValue *> Operands,
@@ -8852,9 +8953,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88528953 Legal->getReductionVars ().find (Phi)->second ;
88538954 assert (RdxDesc.getRecurrenceStartValue () ==
88548955 Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8855- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8856- CM.isInLoopReduction (Phi),
8857- CM.useOrderedReductions (RdxDesc));
8956+
8957+ // If the PHI is used by a partial reduction, set the scale factor.
8958+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8959+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8960+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8961+ PhiRecipe = new VPReductionPHIRecipe (
8962+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8963+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
88588964 } else {
88598965 // TODO: Currently fixed-order recurrences are modeled as chains of
88608966 // first-order recurrences. If there are no users of the intermediate
@@ -8886,6 +8992,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88868992 if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88878993 return tryToWidenMemory (Instr, Operands, Range);
88888994
8995+ if (getScaledReductionForInstr (Instr))
8996+ return tryToCreatePartialReduction (Instr, Operands);
8997+
88898998 if (!shouldWiden (Instr, Range))
88908999 return nullptr ;
88919000
@@ -8906,6 +9015,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
89069015 return tryToWiden (Instr, Operands, VPBB);
89079016}
89089017
9018+ VPRecipeBase *
9019+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
9020+ ArrayRef<VPValue *> Operands) {
9021+ assert (Operands.size () == 2 &&
9022+ " Unexpected number of operands for partial reduction" );
9023+
9024+ VPValue *BinOp = Operands[0 ];
9025+ VPValue *Phi = Operands[1 ];
9026+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9027+ std::swap (BinOp, Phi);
9028+
9029+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
9030+ Reduction);
9031+ }
9032+
89099033void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
89109034 ElementCount MaxVF) {
89119035 assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9223,7 +9347,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92239347 bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92249348 addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
92259349
9226- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9350+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9351+ Builder);
92279352
92289353 // ---------------------------------------------------------------------------
92299354 // Pre-construction: record ingredients whose recipes we'll need to further
@@ -9269,6 +9394,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92699394 bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
92709395 return Legal->blockNeedsPredication (BB) || NeedsBlends;
92719396 });
9397+
9398+ RecipeBuilder.collectScaledReductions (Range);
9399+
92729400 auto *MiddleVPBB = Plan->getMiddleBlock ();
92739401 VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
92749402 for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
0 commit comments