@@ -7605,6 +7605,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
76057605 }
76067606 continue ;
76077607 }
7608+ // The VPlan-based cost model is more accurate for partial reduction and
7609+ // comparing against the legacy cost isn't desirable.
7610+ if (isa<VPPartialReductionRecipe>(&R))
7611+ return true ;
76087612 if (Instruction *UI = GetInstructionForCost (&R))
76097613 SeenInstrs.insert (UI);
76107614 }
@@ -8827,6 +8831,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
88278831 return Recipe;
88288832}
88298833
8834+ // / Find all possible partial reductions in the loop and track all of those that
8835+ // / are valid so recipes can be formed later.
8836+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8837+ // Find all possible partial reductions.
8838+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8839+ PartialReductionChains;
8840+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8841+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8842+ getScaledReduction (Phi, RdxDesc, Range))
8843+ PartialReductionChains.push_back (*Pair);
8844+
8845+ // A partial reduction is invalid if any of its extends are used by
8846+ // something that isn't another partial reduction. This is because the
8847+ // extends are intended to be lowered along with the reduction itself.
8848+
8849+ // Build up a set of partial reduction bin ops for efficient use checking.
8850+ SmallSet<User *, 4 > PartialReductionBinOps;
8851+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8852+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8853+
8854+ auto ExtendIsOnlyUsedByPartialReductions =
8855+ [&PartialReductionBinOps](Instruction *Extend) {
8856+ return all_of (Extend->users (), [&](const User *U) {
8857+ return PartialReductionBinOps.contains (U);
8858+ });
8859+ };
8860+
8861+ // Check if each use of a chain's two extends is a partial reduction
8862+ // and only add those that don't have non-partial reduction users.
8863+ for (auto Pair : PartialReductionChains) {
8864+ PartialReductionChain Chain = Pair.first ;
8865+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8866+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8867+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8868+ }
8869+ }
8870+
8871+ std::optional<std::pair<PartialReductionChain, unsigned >>
8872+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8873+ const RecurrenceDescriptor &Rdx,
8874+ VFRange &Range) {
8875+ // TODO: Allow scaling reductions when predicating. The select at
8876+ // the end of the loop chooses between the phi value and most recent
8877+ // reduction result, both of which have different VFs to the active lane
8878+ // mask when scaling.
8879+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8880+ return std::nullopt ;
8881+
8882+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8883+ if (!Update)
8884+ return std::nullopt ;
8885+
8886+ Value *Op = Update->getOperand (0 );
8887+ if (Op == PHI)
8888+ Op = Update->getOperand (1 );
8889+
8890+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8891+ if (!BinOp || !BinOp->hasOneUse ())
8892+ return std::nullopt ;
8893+
8894+ using namespace llvm ::PatternMatch;
8895+ Value *A, *B;
8896+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8897+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8898+ return std::nullopt ;
8899+
8900+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8901+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8902+
8903+ // Check that the extends extend from the same type.
8904+ if (A->getType () != B->getType ())
8905+ return std::nullopt ;
8906+
8907+ TTI::PartialReductionExtendKind OpAExtend =
8908+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8909+ TTI::PartialReductionExtendKind OpBExtend =
8910+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8911+
8912+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8913+
8914+ unsigned TargetScaleFactor =
8915+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8916+ A->getType ()->getPrimitiveSizeInBits ());
8917+
8918+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8919+ [&](ElementCount VF) {
8920+ InstructionCost Cost = TTI->getPartialReductionCost (
8921+ Update->getOpcode (), A->getType (), PHI->getType (), VF,
8922+ OpAExtend, OpBExtend, std::make_optional (BinOp->getOpcode ()));
8923+ return Cost.isValid ();
8924+ },
8925+ Range))
8926+ return std::make_pair (Chain, TargetScaleFactor);
8927+
8928+ return std::nullopt ;
8929+ }
8930+
88308931VPRecipeBase *
88318932VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
88328933 ArrayRef<VPValue *> Operands,
@@ -8851,9 +8952,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88518952 Legal->getReductionVars ().find (Phi)->second ;
88528953 assert (RdxDesc.getRecurrenceStartValue () ==
88538954 Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8854- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8855- CM.isInLoopReduction (Phi),
8856- CM.useOrderedReductions (RdxDesc));
8955+
8956+ // If the PHI is used by a partial reduction, set the scale factor.
8957+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8958+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8959+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8960+ PhiRecipe = new VPReductionPHIRecipe (
8961+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8962+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
88578963 } else {
88588964 // TODO: Currently fixed-order recurrences are modeled as chains of
88598965 // first-order recurrences. If there are no users of the intermediate
@@ -8885,6 +8991,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88858991 if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88868992 return tryToWidenMemory (Instr, Operands, Range);
88878993
8994+ if (getScaledReductionForInstr (Instr))
8995+ return tryToCreatePartialReduction (Instr, Operands);
8996+
88888997 if (!shouldWiden (Instr, Range))
88898998 return nullptr ;
88908999
@@ -8905,6 +9014,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
89059014 return tryToWiden (Instr, Operands, VPBB);
89069015}
89079016
9017+ VPRecipeBase *
9018+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
9019+ ArrayRef<VPValue *> Operands) {
9020+ assert (Operands.size () == 2 &&
9021+ " Unexpected number of operands for partial reduction" );
9022+
9023+ VPValue *BinOp = Operands[0 ];
9024+ VPValue *Phi = Operands[1 ];
9025+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9026+ std::swap (BinOp, Phi);
9027+
9028+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
9029+ Reduction);
9030+ }
9031+
89089032void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
89099033 ElementCount MaxVF) {
89109034 assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9222,7 +9346,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92229346 bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92239347 addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
92249348
9225- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9349+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9350+ Builder);
92269351
92279352 // ---------------------------------------------------------------------------
92289353 // Pre-construction: record ingredients whose recipes we'll need to further
@@ -9268,6 +9393,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92689393 bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
92699394 return Legal->blockNeedsPredication (BB) || NeedsBlends;
92709395 });
9396+
9397+ RecipeBuilder.collectScaledReductions (Range);
9398+
92719399 auto *MiddleVPBB = Plan->getMiddleBlock ();
92729400 VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
92739401 for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
0 commit comments