@@ -8687,8 +8687,7 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
86878687 SmallVector<std::pair<PartialReductionChain, unsigned >>
86888688 PartialReductionChains;
86898689 for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8690- if (auto SR = getScaledReduction (Phi, RdxDesc.getLoopExitInstr (), Range))
8691- PartialReductionChains.append (*SR);
8690+ getScaledReductions (Phi, RdxDesc.getLoopExitInstr (), Range, PartialReductionChains);
86928691 }
86938692
86948693 // A partial reduction is invalid if any of its extends are used by
@@ -8717,38 +8716,36 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
87178716 }
87188717}
87198718
8720- std::optional<SmallVector<std::pair<PartialReductionChain, unsigned >>>
8721- VPRecipeBuilder::getScaledReduction (Instruction *PHI, Instruction *RdxExitInstr,
8722- VFRange &Range) {
8719+ bool
8720+ VPRecipeBuilder::getScaledReductions (Instruction *PHI, Instruction *RdxExitInstr,
8721+ VFRange &Range, SmallVector<std::pair<PartialReductionChain, unsigned >> &Chains ) {
87238722
87248723 if (!CM.TheLoop ->contains (RdxExitInstr))
8725- return std:: nullopt ;
8724+ return false ;
87268725
87278726 // TODO: Allow scaling reductions when predicating. The select at
87288727 // the end of the loop chooses between the phi value and most recent
87298728 // reduction result, both of which have different VFs to the active lane
87308729 // mask when scaling.
87318730 if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr->getParent ()))
8732- return std:: nullopt ;
8731+ return false ;
87338732
87348733 auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
87358734 if (!Update)
8736- return std:: nullopt ;
8735+ return false ;
87378736
87388737 Value *Op = Update->getOperand (0 );
87398738 Value *PhiOp = Update->getOperand (1 );
87408739 if (Op == PHI)
87418740 std::swap (Op, PhiOp);
87428741
8743- SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
87448742
87458743 // Try and get a scaled reduction from the first non-phi operand.
87468744 // If one is found, we use the discovered reduction instruction in
87478745 // place of the accumulator for costing.
87488746 if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8749- if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8750- Chains.append (*SR0);
8751- PHI = SR0->rbegin ()->first .Reduction ;
8747+ if (getScaledReductions (PHI, OpInst, Range, Chains)) {
8748+ PHI = Chains.rbegin ()->first .Reduction ;
87528749
87538750 Op = Update->getOperand (0 );
87548751 PhiOp = Update->getOperand (1 );
@@ -8757,17 +8754,17 @@ VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87578754 }
87588755 }
87598756 if (PhiOp != PHI)
8760- return std:: nullopt ;
8757+ return false ;
87618758
87628759 auto *BinOp = dyn_cast<BinaryOperator>(Op);
87638760 if (!BinOp || !BinOp->hasOneUse ())
8764- return std:: nullopt ;
8761+ return false ;
87658762
87668763 using namespace llvm ::PatternMatch;
87678764 Value *A, *B;
87688765 if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
87698766 !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8770- return std:: nullopt ;
8767+ return false ;
87718768
87728769 Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
87738770 Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
@@ -8791,10 +8788,12 @@ VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
87918788 std::make_optional (BinOp->getOpcode ()));
87928789 return Cost.isValid ();
87938790 },
8794- Range))
8791+ Range)) {
87958792 Chains.push_back (std::make_pair (Chain, TargetScaleFactor));
8793+ return true ;
8794+ }
87968795
8797- return Chains ;
8796+ return false ;
87988797}
87998798
88008799VPRecipeBase *
0 commit comments