@@ -3519,18 +3519,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35193519 VPValue *VecOp = Red->getVecOp ();
35203520
35213521 // Clamp the range if using extended-reduction is profitable.
3522- auto IsExtendedRedValidAndClampRange = [&]( unsigned Opcode, bool isZExt,
3523- Type *SrcTy) -> bool {
3522+ auto IsExtendedRedValidAndClampRange =
3523+ [&]( unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool {
35243524 return LoopVectorizationPlanner::getDecisionAndClampRange (
35253525 [&](ElementCount VF) {
35263526 auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
35273527 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3528- InstructionCost ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3529- Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags (),
3530- CostKind);
3528+
3529+ InstructionCost ExtRedCost;
35313530 InstructionCost ExtCost =
35323531 cast<VPWidenCastRecipe>(VecOp)->computeCost (VF, Ctx);
35333532 InstructionCost RedCost = Red->computeCost (VF, Ctx);
3533+
3534+ if (isa<VPPartialReductionRecipe>(Red)) {
3535+ TargetTransformInfo::PartialReductionExtendKind ExtKind =
3536+ TargetTransformInfo::getPartialReductionExtendKind (ExtOpc);
3537+ // FIXME: Move partial reduction creation, costing and clamping
3538+ // here from LoopVectorize.cpp.
3539+ ExtRedCost = Ctx.TTI .getPartialReductionCost (
3540+ Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3541+ llvm::TargetTransformInfo::PR_None, std::nullopt , Ctx.CostKind );
3542+ } else {
3543+ ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3544+ Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3545+ Red->getFastMathFlags (), CostKind);
3546+ }
35343547 return ExtRedCost.isValid () && ExtRedCost < ExtCost + RedCost;
35353548 },
35363549 Range);
@@ -3541,8 +3554,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35413554 if (match (VecOp, m_ZExtOrSExt (m_VPValue (A))) &&
35423555 IsExtendedRedValidAndClampRange (
35433556 RecurrenceDescriptor::getOpcode (Red->getRecurrenceKind ()),
3544- cast<VPWidenCastRecipe>(VecOp)->getOpcode () ==
3545- Instruction::CastOps::ZExt,
3557+ cast<VPWidenCastRecipe>(VecOp)->getOpcode (),
35463558 Ctx.Types .inferScalarType (A)))
35473559 return new VPExpressionRecipe (cast<VPWidenCastRecipe>(VecOp), Red);
35483560
@@ -3560,6 +3572,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35603572static VPExpressionRecipe *
35613573tryToMatchAndCreateMulAccumulateReduction (VPReductionRecipe *Red,
35623574 VPCostContext &Ctx, VFRange &Range) {
3575+ bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
3576+
35633577 unsigned Opcode = RecurrenceDescriptor::getOpcode (Red->getRecurrenceKind ());
35643578 if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
35653579 return nullptr ;
@@ -3568,16 +3582,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35683582
35693583 // Clamp the range if using multiply-accumulate-reduction is profitable.
35703584 auto IsMulAccValidAndClampRange =
3571- [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
3572- VPWidenCastRecipe *Ext1, VPWidenCastRecipe * OuterExt) -> bool {
3585+ [&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1 ,
3586+ VPWidenCastRecipe *OuterExt) -> bool {
35733587 return LoopVectorizationPlanner::getDecisionAndClampRange (
35743588 [&](ElementCount VF) {
35753589 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35763590 Type *SrcTy =
35773591 Ext0 ? Ctx.Types .inferScalarType (Ext0->getOperand (0 )) : RedTy;
3578- auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3579- InstructionCost MulAccCost = Ctx.TTI .getMulAccReductionCost (
3580- isZExt, Opcode, RedTy, SrcVecTy, CostKind);
3592+ InstructionCost MulAccCost;
3593+
3594+ if (IsPartialReduction) {
3595+ Type *SrcTy2 =
3596+ Ext1 ? Ctx.Types .inferScalarType (Ext1->getOperand (0 )) : nullptr ;
3597+ // FIXME: Move partial reduction creation, costing and clamping
3598+ // here from LoopVectorize.cpp.
3599+ MulAccCost = Ctx.TTI .getPartialReductionCost (
3600+ Opcode, SrcTy, SrcTy2, RedTy, VF,
3601+ Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3602+ Ext0->getOpcode ())
3603+ : TargetTransformInfo::PR_None,
3604+ Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3605+ Ext1->getOpcode ())
3606+ : TargetTransformInfo::PR_None,
3607+ Mul->getOpcode (), CostKind);
3608+ } else {
3609+ // Only partial reductions support mixed extends at the moment.
3610+ if (Ext0 && Ext1 && Ext0->getOpcode () != Ext1->getOpcode ())
3611+ return false ;
3612+
3613+ bool IsZExt =
3614+ !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
3615+ auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3616+ MulAccCost = Ctx.TTI .getMulAccReductionCost (IsZExt, Opcode, RedTy,
3617+ SrcVecTy, CostKind);
3618+ }
3619+
35813620 InstructionCost MulCost = Mul->computeCost (VF, Ctx);
35823621 InstructionCost RedCost = Red->computeCost (VF, Ctx);
35833622 InstructionCost ExtCost = 0 ;
@@ -3611,23 +3650,18 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36113650 dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
36123651 auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe ());
36133652
3614- // Match reduce.add(mul(ext, ext)).
3615- if (RecipeA && RecipeB &&
3616- (RecipeA->getOpcode () == RecipeB->getOpcode () || A == B) &&
3617- match (RecipeA, m_ZExtOrSExt (m_VPValue ())) &&
3653+ // Match reduce.add/sub(mul(ext, ext)).
3654+ if (RecipeA && RecipeB && match (RecipeA, m_ZExtOrSExt (m_VPValue ())) &&
36183655 match (RecipeB, m_ZExtOrSExt (m_VPValue ())) &&
3619- IsMulAccValidAndClampRange (RecipeA->getOpcode () ==
3620- Instruction::CastOps::ZExt,
3621- Mul, RecipeA, RecipeB, nullptr )) {
3656+ IsMulAccValidAndClampRange (Mul, RecipeA, RecipeB, nullptr )) {
36223657 if (Sub)
36233658 return new VPExpressionRecipe (RecipeA, RecipeB, Mul,
36243659 cast<VPWidenRecipe>(Sub), Red);
36253660 return new VPExpressionRecipe (RecipeA, RecipeB, Mul, Red);
36263661 }
36273662 // Match reduce.add(mul).
36283663 // TODO: Add an expression type for this variant with a negated mul
3629- if (!Sub &&
3630- IsMulAccValidAndClampRange (true , Mul, nullptr , nullptr , nullptr ))
3664+ if (!Sub && IsMulAccValidAndClampRange (Mul, nullptr , nullptr , nullptr ))
36313665 return new VPExpressionRecipe (Mul, Red);
36323666 }
36333667 // TODO: Add an expression type for negated versions of other expression
@@ -3647,9 +3681,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36473681 cast<VPWidenCastRecipe>(Mul->getOperand (1 )->getDefiningRecipe ());
36483682 if ((Ext->getOpcode () == Ext0->getOpcode () || Ext0 == Ext1) &&
36493683 Ext0->getOpcode () == Ext1->getOpcode () &&
3650- IsMulAccValidAndClampRange (Ext0->getOpcode () ==
3651- Instruction::CastOps::ZExt,
3652- Mul, Ext0, Ext1, Ext)) {
3684+ IsMulAccValidAndClampRange (Mul, Ext0, Ext1, Ext) && Mul->hasOneUse ()) {
36533685 auto *NewExt0 = new VPWidenCastRecipe (
36543686 Ext0->getOpcode (), Ext0->getOperand (0 ), Ext->getResultType (), *Ext0,
36553687 *Ext0, Ext0->getDebugLoc ());
0 commit comments