@@ -8790,12 +8790,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87908790// / are valid so recipes can be formed later.
87918791void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
87928792 // Find all possible partial reductions.
8793- SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8793+ SmallVector<std::pair<PartialReductionChain, unsigned >>
87948794 PartialReductionChains;
8795- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8796- if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8797- getScaledReduction (Phi, RdxDesc, Range))
8798- PartialReductionChains. push_back (*Pair);
8795+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8796+ if (auto SR = getScaledReduction (Phi, RdxDesc. getLoopExitInstr (), Range))
8797+ PartialReductionChains. append (*SR);
8798+ }
87998799
88008800 // A partial reduction is invalid if any of its extends are used by
88018801 // something that isn't another partial reduction. This is because the
@@ -8823,26 +8823,42 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
88238823 }
88248824}
88258825
8826- std::optional<std::pair<PartialReductionChain, unsigned >>
8827- VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8828- const RecurrenceDescriptor &Rdx ,
8826+ std::optional<SmallVector< std::pair<PartialReductionChain, unsigned > >>
8827+ VPRecipeBuilder::getScaledReduction (Instruction *PHI,
8828+ Instruction *RdxExitInstr ,
88298829 VFRange &Range) {
8830+
8831+ if (!CM.TheLoop ->contains (RdxExitInstr))
8832+ return std::nullopt ;
8833+
88308834 // TODO: Allow scaling reductions when predicating. The select at
88318835 // the end of the loop chooses between the phi value and most recent
88328836 // reduction result, both of which have different VFs to the active lane
88338837 // mask when scaling.
8834- if (CM.blockNeedsPredicationForAnyReason (Rdx. getLoopExitInstr () ->getParent ()))
8838+ if (CM.blockNeedsPredicationForAnyReason (RdxExitInstr ->getParent ()))
88358839 return std::nullopt ;
88368840
8837- auto *Update = dyn_cast<BinaryOperator>(Rdx. getLoopExitInstr () );
8841+ auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr );
88388842 if (!Update)
88398843 return std::nullopt ;
88408844
88418845 Value *Op = Update->getOperand (0 );
88428846 Value *PhiOp = Update->getOperand (1 );
8843- if (Op == PHI) {
8844- Op = Update->getOperand (1 );
8845- PhiOp = Update->getOperand (0 );
8847+ if (Op == PHI)
8848+ std::swap (Op, PhiOp);
8849+
8850+ SmallVector<std::pair<PartialReductionChain, unsigned >> Chains;
8851+
8852+ if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8853+ if (auto SR0 = getScaledReduction (PHI, OpInst, Range)) {
8854+ Chains.append (*SR0);
8855+ PHI = SR0->rbegin ()->first .Reduction ;
8856+
8857+ Op = Update->getOperand (0 );
8858+ PhiOp = Update->getOperand (1 );
8859+ if (Op == PHI)
8860+ std::swap (Op, PhiOp);
8861+ }
88468862 }
88478863 if (PhiOp != PHI)
88488864 return std::nullopt ;
@@ -8860,12 +8876,16 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88608876 Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
88618877 Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
88628878
8879+ // Check that the extends extend from the same type.
8880+ if (A->getType () != B->getType ())
8881+ return std::nullopt ;
8882+
88638883 TTI::PartialReductionExtendKind OpAExtend =
88648884 TargetTransformInfo::getPartialReductionExtendKind (ExtA);
88658885 TTI::PartialReductionExtendKind OpBExtend =
88668886 TargetTransformInfo::getPartialReductionExtendKind (ExtB);
88678887
8868- PartialReductionChain Chain (Rdx. getLoopExitInstr () , ExtA, ExtB, BinOp);
8888+ PartialReductionChain Chain (RdxExitInstr , ExtA, ExtB, BinOp);
88698889
88708890 unsigned TargetScaleFactor =
88718891 PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
@@ -8880,9 +8900,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
88808900 return Cost.isValid ();
88818901 },
88828902 Range))
8883- return std::make_pair (Chain, TargetScaleFactor);
8903+ Chains. push_back ( std::make_pair (Chain, TargetScaleFactor) );
88848904
8885- return std:: nullopt ;
8905+ return Chains ;
88868906}
88878907
88888908VPRecipeBase *
@@ -8979,7 +8999,8 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
89798999
89809000 VPValue *BinOp = Operands[0 ];
89819001 VPValue *Phi = Operands[1 ];
8982- if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
9002+ VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
9003+ if (isa<VPReductionPHIRecipe>(BinOpRecipe) || isa<VPPartialReductionRecipe>(BinOpRecipe))
89839004 std::swap (BinOp, Phi);
89849005
89859006 return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
0 commit comments