@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86828682// / are valid so recipes can be formed later.
86838683void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
86848684 // Find all possible partial reductions.
8685- SmallVector<std::pair<PartialReductionChain, unsigned >>
8685+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
86868686 PartialReductionChains;
8687- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8688- if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8689- PartialReductionChains. append (*SR);
8690- }
8687+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8688+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8689+ getScaledReduction (Phi, RdxDesc, Range))
8690+ PartialReductionChains. push_back (*Pair);
86918691
86928692 // A partial reduction is invalid if any of its extends are used by
86938693 // something that isn't another partial reduction. This is because the
@@ -8715,44 +8715,26 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87158715 }
87168716}
87178717
8718- std::optional<SmallVector<std::pair<PartialReductionChain, unsigned >>>
8719- VPRecipeBuilder::getScaledReduction (Instruction *PHI, Instruction *RdxExitInstr,
8718+ std::optional<std::pair<PartialReductionChain, unsigned >>
8719+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8720+ const RecurrenceDescriptor &Rdx,
87208721 VFRange &Range) {
8721-
8722- if (!CM.TheLoop ->contains (RdxExitInstr))
8723- return std::nullopt ;
8724-
87258722 // TODO: Allow scaling reductions when predicating. The select at
87268723 // the end of the loop chooses between the phi value and most recent
87278724 // reduction result, both of which have different VFs to the active lane
87288725 // mask when scaling.
8729- if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8726+ if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
87308727 return std::nullopt ;
87318728
8732- auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
8729+ auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
87338730 if (!Update)
87348731 return std::nullopt ;
87358732
87368733 Value *Op = Update->getOperand (0 );
87378734 Value *PhiOp = Update->getOperand (1 );
8738- if (Op == PHI)
8739- std::swap (Op, PhiOp);
8740-
8741- SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8742-
8743- // Try and get a scaled reduction from the first non-phi operand.
8744- // If one is found, we use the discovered reduction instruction in
8745- // place of the accumulator for costing.
8746- if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747- if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8748- Chains.append (*SR0);
8749- PHI = SR0->rbegin ()->first .Reduction ;
8750-
8751- Op = Update->getOperand (0 );
8752- PhiOp = Update->getOperand (1 );
8753- if (Op == PHI)
8754- std::swap (Op, PhiOp);
8755- }
8735+ if (Op == PHI) {
8736+ Op = Update->getOperand (1 );
8737+ PhiOp = Update->getOperand (0 );
87568738 }
87578739 if (PhiOp != PHI)
87588740 return std::nullopt ;
@@ -8775,7 +8757,7 @@ VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87758757 TTI::PartialReductionExtendKind OpBExtend =
87768758 TargetTransformInfo::getPartialReductionExtendKind (ExtB);
87778759
8778- PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
8760+ PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
87798761
87808762 unsigned TargetScaleFactor =
87818763 PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8790,9 +8772,9 @@ VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87908772 return Cost.isValid ();
87918773 },
87928774 Range))
8793- Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
8775+ return std::make_pair (Chain, TargetScaleFactor);
87948776
8795- return Chains ;
8777+ return std:: nullopt ;
87968778}
87978779
87988780VPRecipeBase *
@@ -8887,14 +8869,12 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88878869 " Unexpected number of operands for partial reduction" );
88888870
88898871 VPValue *BinOp = Operands[0 ];
8890- VPValue *Accumulator = Operands[1 ];
8891- VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8892- if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8893- isa<VPPartialReductionRecipe>(BinOpRecipe))
8894- std::swap (BinOp, Accumulator);
8895-
8896- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8897- Accumulator, Reduction);
8872+ VPValue *Phi = Operands[1 ];
8873+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8874+ std::swap (BinOp, Phi);
8875+
8876+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8877+ Reduction);
88988878}
88998879
89008880void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments