@@ -8799,12 +8799,10 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87998799/// are valid so recipes can be formed later.
88008800void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
88018801 // Find all possible partial reductions.
8802- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8802+ SmallVector<std::pair<PartialReductionChain, unsigned >>
88038803 PartialReductionChains;
88048804 for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8805- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8806- getScaledReduction (Phi, RdxDesc, Range))
8807- PartialReductionChains.push_back (*Pair);
8805+ PartialReductionChains.append (getScaledReduction (Phi, RdxDesc.getLoopExitInstr (), Range));
88088806
88098807 // A partial reduction is invalid if any of its extends are used by
88108808 // something that isn't another partial reduction. This is because the
@@ -8832,48 +8830,65 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
88328830 }
88338831}
88348832
8835- std::optional <std::pair<PartialReductionChain, unsigned >>
8836- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8837- const RecurrenceDescriptor &Rdx ,
8833+ SmallVector <std::pair<PartialReductionChain, unsigned >>
8834+ VPRecipeBuilder::getScaledReduction (Instruction *PHI,
8835+ Instruction *RdxExitInstr ,
88388836 VFRange &Range) {
8837+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8838+
8839+ if (!CM.TheLoop ->contains (RdxExitInstr))
8840+ return Chains;
8841+
88398842 // TODO: Allow scaling reductions when predicating. The select at
88408843 // the end of the loop chooses between the phi value and most recent
88418844 // reduction result, both of which have different VFs to the active lane
88428845 // mask when scaling.
8843- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8844- return std:: nullopt ;
8846+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
8847+ return Chains ;
88458848
8846- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8849+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
88478850 if (!Update)
8848- return std:: nullopt ;
8851+ return Chains ;
88498852
88508853 Value *Op = Update->getOperand (0 );
88518854 if (Op == PHI)
88528855 Op = Update->getOperand (1 );
88538856
8857+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8858+ auto SR0 = getScaledReduction (PHI, OpInst, Range);
8859+ if (!SR0.empty ()) {
8860+ Chains.append (SR0);
8861+ PHI = SR0.rbegin ()->first .Reduction ;
8862+
8863+ Op = Update->getOperand (0 );
8864+ if (Op == PHI)
8865+ Op = Update->getOperand (1 );
8866+ }
8867+ }
8868+
88548869 auto *BinOp = dyn_cast<BinaryOperator>(Op);
88558870 if (!BinOp || !BinOp->hasOneUse ())
8856- return std:: nullopt ;
8871+ return Chains ;
88578872
88588873 using namespace llvm ::PatternMatch;
88598874 Value *A, *B;
88608875 if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
88618876 !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8862- return std:: nullopt ;
8877+ return Chains ;
88638878
88648879 Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
88658880 Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
88668881
88678882 // Check that the extends extend from the same type.
88688883 if (A->getType () != B->getType ())
8869- return std:: nullopt ;
8884+ return Chains ;
88708885
88718886 TTI::PartialReductionExtendKind OpAExtend =
88728887 TargetTransformInfo::getPartialReductionExtendKind (ExtA);
88738888 TTI::PartialReductionExtendKind OpBExtend =
88748889 TargetTransformInfo::getPartialReductionExtendKind (ExtB);
88758890
8876- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8891+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
88778892
88788893 unsigned TargetScaleFactor =
88798894 PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8887,9 +8902,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88878902 return Cost.isValid ();
88888903 },
88898904 Range))
8890- return std::make_pair (Chain, TargetScaleFactor);
8905+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
88918906
8892- return std:: nullopt ;
8907+ return Chains ;
88938908}
88948909
88958910VPRecipeBase *
@@ -8986,7 +9001,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
89869001
89879002 VPValue *BinOp = Operands[0 ];
89889003 VPValue *Phi = Operands[1 ];
8989- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9004+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
9005+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
89909006 std::swap (BinOp, Phi);
89919007
89929008 return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
0 commit comments