@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86828682// / are valid so recipes can be formed later.
86838683void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
86848684 // Find all possible partial reductions.
8685- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8685+ SmallVector<std::pair<PartialReductionChain, unsigned >>
86868686 PartialReductionChains;
8687- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8688- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8689- getScaledReduction (Phi, RdxDesc, Range))
8690- PartialReductionChains. push_back (*Pair);
8687+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8688+ if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8689+ PartialReductionChains. append (*SR);
8690+ }
86918691
86928692 // A partial reduction is invalid if any of its extends are used by
86938693 // something that isn't another partial reduction. This is because the
@@ -8715,26 +8715,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87158715 }
87168716}
87178717
8718- std::optional<std::pair<PartialReductionChain, unsigned >>
8719- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8720- const RecurrenceDescriptor &Rdx,
8718+ std::optional<SmallVector<std::pair<PartialReductionChain, unsigned >>>
8719+ VPRecipeBuilder::getScaledReduction (Instruction *PHI, Instruction *RdxExitInstr,
87218720 VFRange &Range) {
8721+
8722+ if (!CM.TheLoop ->contains (RdxExitInstr))
8723+ return std::nullopt ;
8724+
87228725 // TODO: Allow scaling reductions when predicating. The select at
87238726 // the end of the loop chooses between the phi value and most recent
87248727 // reduction result, both of which have different VFs to the active lane
87258728 // mask when scaling.
8726- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8729+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
87278730 return std::nullopt ;
87288731
8729- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8732+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
87308733 if (!Update)
87318734 return std::nullopt ;
87328735
87338736 Value *Op = Update->getOperand (0 );
87348737 Value *PhiOp = Update->getOperand (1 );
8735- if (Op == PHI) {
8736- Op = Update->getOperand (1 );
8737- PhiOp = Update->getOperand (0 );
8738+ if (Op == PHI)
8739+ std::swap (Op, PhiOp);
8740+
8741+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8742+
8743+ // Try and get a scaled reduction from the first non-phi operand.
8744+ // If one is found, we use the discovered reduction instruction in
8745+ // place of the accumulator for costing.
8746+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747+ if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8748+ Chains.append (*SR0);
8749+ PHI = SR0->rbegin ()->first .Reduction ;
8750+
8751+ Op = Update->getOperand (0 );
8752+ PhiOp = Update->getOperand (1 );
8753+ if (Op == PHI)
8754+ std::swap (Op, PhiOp);
8755+ }
87388756 }
87398757 if (PhiOp != PHI)
87408758 return std::nullopt ;
@@ -8757,7 +8775,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87578775 TTI::PartialReductionExtendKind OpBExtend =
87588776 TargetTransformInfo::getPartialReductionExtendKind (ExtB);
87598777
8760- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8778+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
87618779
87628780 unsigned TargetScaleFactor =
87638781 PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8772,9 +8790,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87728790 return Cost.isValid ();
87738791 },
87748792 Range))
8775- return std::make_pair (Chain, TargetScaleFactor);
8793+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
87768794
8777- return std:: nullopt ;
8795+ return Chains ;
87788796}
87798797
87808798VPRecipeBase *
@@ -8869,12 +8887,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88698887 " Unexpected number of operands for partial reduction" );
88708888
88718889 VPValue *BinOp = Operands[0 ];
8872- VPValue *Phi = Operands[1 ];
8873- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8874- std::swap (BinOp, Phi);
8875-
8876- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8877- Reduction);
8890+ VPValue *Accumulator = Operands[1 ];
8891+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8892+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8893+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8894+ std::swap (BinOp, Accumulator);
8895+
8896+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8897+ Accumulator, Reduction);
88788898}
88798899
88808900void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments