@@ -7532,6 +7532,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
75327532 }
75337533 continue ;
75347534 }
7535+ // The VPlan-based cost model is more accurate for partial reduction and
7536+ // comparing against the legacy cost isn't desirable.
7537+ if (isa<VPPartialReductionRecipe>(&R))
7538+ return true ;
75357539 if (Instruction *UI = GetInstructionForCost (&R))
75367540 SeenInstrs.insert (UI);
75377541 }
@@ -8746,6 +8750,103 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87468750 return Recipe;
87478751}
87488752
8753+ // / Find all possible partial reductions in the loop and track all of those that
8754+ // / are valid so recipes can be formed later.
8755+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8756+ // Find all possible partial reductions.
8757+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8758+ PartialReductionChains;
8759+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8760+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8761+ getScaledReduction (Phi, RdxDesc, Range))
8762+ PartialReductionChains.push_back (*Pair);
8763+
8764+ // A partial reduction is invalid if any of its extends are used by
8765+ // something that isn't another partial reduction. This is because the
8766+ // extends are intended to be lowered along with the reduction itself.
8767+
8768+ // Build up a set of partial reduction bin ops for efficient use checking.
8769+ SmallSet<User *, 4 > PartialReductionBinOps;
8770+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8771+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8772+
8773+ auto ExtendIsOnlyUsedByPartialReductions =
8774+ [&PartialReductionBinOps](Instruction *Extend) {
8775+ return all_of (Extend->users (), [&](const User *U) {
8776+ return PartialReductionBinOps.contains (U);
8777+ });
8778+ };
8779+
8780+ // Check if each use of a chain's two extends is a partial reduction
8781+ // and only add those that don't have non-partial reduction users.
8782+ for (auto Pair : PartialReductionChains) {
8783+ PartialReductionChain Chain = Pair.first ;
8784+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8785+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8786+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8787+ }
8788+ }
8789+
8790+ std::optional<std::pair<PartialReductionChain, unsigned >>
8791+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8792+ const RecurrenceDescriptor &Rdx,
8793+ VFRange &Range) {
8794+ // TODO: Allow scaling reductions when predicating. The select at
8795+ // the end of the loop chooses between the phi value and most recent
8796+ // reduction result, both of which have different VFs to the active lane
8797+ // mask when scaling.
8798+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8799+ return std::nullopt ;
8800+
8801+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8802+ if (!Update)
8803+ return std::nullopt ;
8804+
8805+ Value *Op = Update->getOperand (0 );
8806+ if (Op == PHI)
8807+ Op = Update->getOperand (1 );
8808+
8809+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8810+ if (!BinOp || !BinOp->hasOneUse ())
8811+ return std::nullopt ;
8812+
8813+ using namespace llvm ::PatternMatch;
8814+ Value *A, *B;
8815+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8816+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8817+ return std::nullopt ;
8818+
8819+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8820+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8821+
8822+ // Check that the extends extend from the same type.
8823+ if (A->getType () != B->getType ())
8824+ return std::nullopt ;
8825+
8826+ TTI::PartialReductionExtendKind OpAExtend =
8827+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8828+ TTI::PartialReductionExtendKind OpBExtend =
8829+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8830+
8831+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8832+
8833+ unsigned TargetScaleFactor =
8834+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8835+ A->getType ()->getPrimitiveSizeInBits ());
8836+
8837+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8838+ [&](ElementCount VF) {
8839+ InstructionCost Cost = TTI->getPartialReductionCost (
8840+ Update->getOpcode (), A->getType (), PHI->getType (), VF,
8841+ OpAExtend, OpBExtend, std::make_optional (BinOp->getOpcode ()));
8842+ return Cost.isValid ();
8843+ },
8844+ Range))
8845+ return std::make_pair (Chain, TargetScaleFactor);
8846+
8847+ return std::nullopt ;
8848+ }
8849+
87498850VPRecipeBase *
87508851VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
87518852 ArrayRef<VPValue *> Operands,
@@ -8770,9 +8871,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
87708871 Legal->getReductionVars ().find (Phi)->second ;
87718872 assert (RdxDesc.getRecurrenceStartValue () ==
87728873 Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8773- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8774- CM.isInLoopReduction (Phi),
8775- CM.useOrderedReductions (RdxDesc));
8874+
8875+ // If the PHI is used by a partial reduction, set the scale factor.
8876+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8877+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8878+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8879+ PhiRecipe = new VPReductionPHIRecipe (
8880+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8881+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
87768882 } else {
87778883 // TODO: Currently fixed-order recurrences are modeled as chains of
87788884 // first-order recurrences. If there are no users of the intermediate
@@ -8804,6 +8910,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88048910 if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88058911 return tryToWidenMemory (Instr, Operands, Range);
88068912
8913+ if (getScaledReductionForInstr (Instr))
8914+ return tryToCreatePartialReduction (Instr, Operands);
8915+
88078916 if (!shouldWiden (Instr, Range))
88088917 return nullptr ;
88098918
@@ -8824,6 +8933,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88248933 return tryToWiden (Instr, Operands, VPBB);
88258934}
88268935
8936+ VPRecipeBase *
8937+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
8938+ ArrayRef<VPValue *> Operands) {
8939+ assert (Operands.size () == 2 &&
8940+ " Unexpected number of operands for partial reduction" );
8941+
8942+ VPValue *BinOp = Operands[0 ];
8943+ VPValue *Phi = Operands[1 ];
8944+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8945+ std::swap (BinOp, Phi);
8946+
8947+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8948+ Reduction);
8949+ }
8950+
88278951void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
88288952 ElementCount MaxVF) {
88298953 assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9247,7 +9371,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92479371 bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92489372 addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
92499373
9250- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9374+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9375+ Builder);
92519376
92529377 // ---------------------------------------------------------------------------
92539378 // Pre-construction: record ingredients whose recipes we'll need to further
@@ -9293,6 +9418,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92939418 bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
92949419 return Legal->blockNeedsPredication (BB) || NeedsBlends;
92959420 });
9421+
9422+ RecipeBuilder.collectScaledReductions (Range);
9423+
92969424 auto *MiddleVPBB = Plan->getMiddleBlock ();
92979425 VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
92989426 for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
0 commit comments