@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86848684/// are valid so recipes can be formed later.
86858685void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
86868686 // Find all possible partial reductions.
8687- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8687+ SmallVector<std::pair<PartialReductionChain, unsigned >>
86888688 PartialReductionChains;
8689- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8690- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8691- getScaledReduction (Phi, RdxDesc, Range))
8692- PartialReductionChains. push_back (*Pair);
8689+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8690+ if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8691+ PartialReductionChains. append (*SR);
8692+ }
86938693
86948694 // A partial reduction is invalid if any of its extends are used by
86958695 // something that isn't another partial reduction. This is because the
@@ -8717,26 +8717,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87178717 }
87188718}
87198719
8720- std::optional<std::pair<PartialReductionChain, unsigned >>
8721- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8722- const RecurrenceDescriptor &Rdx,
8720+ std::optional<SmallVector<std::pair<PartialReductionChain, unsigned >>>
8721+ VPRecipeBuilder::getScaledReduction (Instruction *PHI, Instruction *RdxExitInstr,
87238722 VFRange &Range) {
8723+
8724+ if (!CM.TheLoop ->contains (RdxExitInstr))
8725+ return std::nullopt ;
8726+
87248727 // TODO: Allow scaling reductions when predicating. The select at
87258728 // the end of the loop chooses between the phi value and most recent
87268729 // reduction result, both of which have different VFs to the active lane
87278730 // mask when scaling.
8728- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8731+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
87298732 return std::nullopt ;
87308733
8731- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8734+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
87328735 if (!Update)
87338736 return std::nullopt ;
87348737
87358738 Value *Op = Update->getOperand (0 );
87368739 Value *PhiOp = Update->getOperand (1 );
8737- if (Op == PHI) {
8738- Op = Update->getOperand (1 );
8739- PhiOp = Update->getOperand (0 );
8740+ if (Op == PHI)
8741+ std::swap (Op, PhiOp);
8742+
8743+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8744+
8745+ // Try and get a scaled reduction from the first non-phi operand.
8746+ // If one is found, we use the discovered reduction instruction in
8747+ // place of the accumulator for costing.
8748+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8749+ if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8750+ Chains.append (*SR0);
8751+ PHI = SR0->rbegin ()->first .Reduction ;
8752+
8753+ Op = Update->getOperand (0 );
8754+ PhiOp = Update->getOperand (1 );
8755+ if (Op == PHI)
8756+ std::swap (Op, PhiOp);
8757+ }
87408758 }
87418759 if (PhiOp != PHI)
87428760 return std::nullopt ;
@@ -8759,7 +8777,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87598777 TTI::PartialReductionExtendKind OpBExtend =
87608778 TargetTransformInfo::getPartialReductionExtendKind (ExtB);
87618779
8762- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8780+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
87638781
87648782 unsigned TargetScaleFactor =
87658783 PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8774,9 +8792,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87748792 return Cost.isValid ();
87758793 },
87768794 Range))
8777- return std::make_pair (Chain, TargetScaleFactor);
8795+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
87788796
8779- return std:: nullopt ;
8797+ return Chains ;
87808798}
87818799
87828800VPRecipeBase *
@@ -8871,12 +8889,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88718889 " Unexpected number of operands for partial reduction" );
88728890
88738891 VPValue *BinOp = Operands[0 ];
8874- VPValue *Phi = Operands[1 ];
8875- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8876- std::swap (BinOp, Phi);
8877-
8878- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8879- Reduction);
8892+ VPValue *Accumulator = Operands[1 ];
8893+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8894+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8895+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8896+ std::swap (BinOp, Accumulator);
8897+
8898+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8899+ Accumulator, Reduction);
88808900}
88818901
88828902void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments