@@ -8684,12 +8684,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
86848684// / are valid so recipes can be formed later.
86858685void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
86868686 // Find all possible partial reductions.
8687- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8687+ SmallVector<std::pair<PartialReductionChain, unsigned >>
86888688 PartialReductionChains;
8689- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8690- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8691- getScaledReduction (Phi, RdxDesc, Range))
8692- PartialReductionChains. push_back (*Pair);
8689+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8690+ getScaledReductions (Phi, RdxDesc. getLoopExitInstr (), Range,
8691+ PartialReductionChains);
8692+ }
86938693
86948694 // A partial reduction is invalid if any of its extends are used by
86958695 // something that isn't another partial reduction. This is because the
@@ -8717,39 +8717,54 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87178717 }
87188718}
87198719
8720- std::optional<std::pair<PartialReductionChain, unsigned >>
8721- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8722- const RecurrenceDescriptor &Rdx,
8723- VFRange &Range) {
8720+ bool VPRecipeBuilder::getScaledReductions (
8721+ Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8722+ SmallVectorImpl<std::pair<PartialReductionChain, unsigned >> &Chains) {
8723+
8724+ if (!CM.TheLoop ->contains (RdxExitInstr))
8725+ return false ;
8726+
87248727 // TODO: Allow scaling reductions when predicating. The select at
87258728 // the end of the loop chooses between the phi value and most recent
87268729 // reduction result, both of which have different VFs to the active lane
87278730 // mask when scaling.
8728- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8729- return std:: nullopt ;
8731+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8732+ return false ;
87308733
8731- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8734+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
87328735 if (!Update)
8733- return std:: nullopt ;
8736+ return false ;
87348737
87358738 Value *Op = Update->getOperand (0 );
87368739 Value *PhiOp = Update->getOperand (1 );
8737- if (Op == PHI) {
8738- Op = Update->getOperand (1 );
8739- PhiOp = Update->getOperand (0 );
8740+ if (Op == PHI)
8741+ std::swap (Op, PhiOp);
8742+
8743+ // Try and get a scaled reduction from the first non-phi operand.
8744+ // If one is found, we use the discovered reduction instruction in
8745+ // place of the accumulator for costing.
8746+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8747+ if (getScaledReductions (PHI, OpInst, Range, Chains)) {
8748+ PHI = Chains.rbegin ()->first .Reduction ;
8749+
8750+ Op = Update->getOperand (0 );
8751+ PhiOp = Update->getOperand (1 );
8752+ if (Op == PHI)
8753+ std::swap (Op, PhiOp);
8754+ }
87408755 }
87418756 if (PhiOp != PHI)
8742- return std:: nullopt ;
8757+ return false ;
87438758
87448759 auto *BinOp = dyn_cast<BinaryOperator>(Op);
87458760 if (!BinOp || !BinOp->hasOneUse ())
8746- return std:: nullopt ;
8761+ return false ;
87478762
87488763 using namespace llvm ::PatternMatch;
87498764 Value *A, *B;
87508765 if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
87518766 !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8752- return std:: nullopt ;
8767+ return false ;
87538768
87548769 Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
87558770 Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
@@ -8759,7 +8774,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87598774 TTI::PartialReductionExtendKind OpBExtend =
87608775 TargetTransformInfo::getPartialReductionExtendKind (ExtB);
87618776
8762- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8777+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
87638778
87648779 unsigned TargetScaleFactor =
87658780 PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8773,10 +8788,12 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
87738788 std::make_optional (BinOp->getOpcode ()));
87748789 return Cost.isValid ();
87758790 },
8776- Range))
8777- return std::make_pair (Chain, TargetScaleFactor);
8791+ Range)) {
8792+ Chains.push_back (std::make_pair (Chain, TargetScaleFactor));
8793+ return true ;
8794+ }
87788795
8779- return std:: nullopt ;
8796+ return false ;
87808797}
87818798
87828799VPRecipeBase *
@@ -8871,12 +8888,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88718888 " Unexpected number of operands for partial reduction" );
88728889
88738890 VPValue *BinOp = Operands[0 ];
8874- VPValue *Phi = Operands[1 ];
8875- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8876- std::swap (BinOp, Phi);
8877-
8878- return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8879- Reduction);
8891+ VPValue *Accumulator = Operands[1 ];
8892+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8893+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8894+ isa<VPPartialReductionRecipe>(BinOpRecipe))
8895+ std::swap (BinOp, Accumulator);
8896+
8897+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp,
8898+ Accumulator, Reduction);
88808899}
88818900
88828901void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
0 commit comments