@@ -3472,25 +3472,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
34723472 auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
34733473 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
34743474
3475- InstructionCost ExtRedCost;
34763475 InstructionCost ExtCost =
34773476 cast<VPWidenCastRecipe>(VecOp)->computeCost (VF, Ctx);
34783477 InstructionCost RedCost = Red->computeCost (VF, Ctx);
3478+ InstructionCost BaseCost = ExtCost + RedCost;
34793479
34803480 if (isa<VPPartialReductionRecipe>(Red)) {
34813481 TargetTransformInfo::PartialReductionExtendKind ExtKind =
34823482 TargetTransformInfo::getPartialReductionExtendKind (ExtOpc);
34833483 // FIXME: Move partial reduction creation, costing and clamping
34843484 // here from LoopVectorize.cpp.
3485- ExtRedCost = Ctx.TTI .getPartialReductionCost (
3486- Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3487- llvm::TargetTransformInfo::PR_None, std::nullopt , Ctx.CostKind );
3488- } else {
3489- ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3490- Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3491- Red->getFastMathFlags (), CostKind);
3485+ InstructionCost PartialReductionCost =
3486+ Ctx.TTI .getPartialReductionCost (
3487+ Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3488+ llvm::TargetTransformInfo::PR_None, std::nullopt ,
3489+ Ctx.CostKind );
3490+ assert (PartialReductionCost <= BaseCost &&
3491+ " A partial reduction should have a lower cost than the "
3492+ " extend + add" );
3493+ return true ;
34923494 }
3493- return ExtRedCost.isValid () && ExtRedCost < ExtCost + RedCost;
3495+
3496+ InstructionCost ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3497+ Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3498+ Red->getFastMathFlags (), CostKind);
3499+ return ExtRedCost.isValid () && ExtRedCost < BaseCost;
34943500 },
34953501 Range);
34963502 };
@@ -3535,46 +3541,50 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35353541 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35363542 Type *SrcTy =
35373543 Ext0 ? Ctx.Types .inferScalarType (Ext0->getOperand (0 )) : RedTy;
3538- InstructionCost MulAccCost;
3544+
3545+ InstructionCost MulCost = Mul->computeCost (VF, Ctx);
3546+ InstructionCost RedCost = Red->computeCost (VF, Ctx);
3547+ InstructionCost ExtCost = 0 ;
3548+ if (Ext0)
3549+ ExtCost += Ext0->computeCost (VF, Ctx);
3550+ if (Ext1)
3551+ ExtCost += Ext1->computeCost (VF, Ctx);
3552+ if (OuterExt)
3553+ ExtCost += OuterExt->computeCost (VF, Ctx);
3554+ InstructionCost BaseCost = ExtCost + MulCost + RedCost;
35393555
35403556 if (IsPartialReduction) {
35413557 Type *SrcTy2 =
35423558 Ext1 ? Ctx.Types .inferScalarType (Ext1->getOperand (0 )) : nullptr ;
35433559 // FIXME: Move partial reduction creation, costing and clamping
35443560 // here from LoopVectorize.cpp.
3545- MulAccCost = Ctx.TTI .getPartialReductionCost (
3546- Opcode, SrcTy, SrcTy2, RedTy, VF,
3547- Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3548- Ext0->getOpcode ())
3549- : TargetTransformInfo::PR_None,
3550- Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3551- Ext1->getOpcode ())
3552- : TargetTransformInfo::PR_None,
3553- Mul->getOpcode (), CostKind);
3554- } else {
3561+ InstructionCost PartialReductionCost =
3562+ Ctx.TTI .getPartialReductionCost (
3563+ Opcode, SrcTy, SrcTy2, RedTy, VF,
3564+ Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3565+ Ext0->getOpcode ())
3566+ : TargetTransformInfo::PR_None,
3567+ Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3568+ Ext1->getOpcode ())
3569+ : TargetTransformInfo::PR_None,
3570+ Mul->getOpcode (), CostKind);
3571+ assert (PartialReductionCost <= BaseCost &&
3572+ " A partial reduction should have a lower cost than the "
3573+ " extend + mul + add" );
3574+ return true ;
3575+ }
35553576 // Only partial reductions support mixed extends at the moment.
35563577 if (Ext0 && Ext1 && Ext0->getOpcode () != Ext1->getOpcode ())
35573578 return false ;
35583579
35593580 bool IsZExt =
35603581 !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
35613582 auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3562- MulAccCost = Ctx.TTI .getMulAccReductionCost (IsZExt, Opcode, RedTy,
3563- SrcVecTy, CostKind);
3564- }
3565-
3566- InstructionCost MulCost = Mul->computeCost (VF, Ctx);
3567- InstructionCost RedCost = Red->computeCost (VF, Ctx);
3568- InstructionCost ExtCost = 0 ;
3569- if (Ext0)
3570- ExtCost += Ext0->computeCost (VF, Ctx);
3571- if (Ext1)
3572- ExtCost += Ext1->computeCost (VF, Ctx);
3573- if (OuterExt)
3574- ExtCost += OuterExt->computeCost (VF, Ctx);
3583+ InstructionCost MulAccCost = Ctx.TTI .getMulAccReductionCost (
3584+ IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
35753585
3576- return MulAccCost.isValid () &&
3577- MulAccCost < ExtCost + MulCost + RedCost;
3586+ return MulAccCost.isValid () &&
3587+ MulAccCost < ExtCost + MulCost + RedCost;
35783588 },
35793589 Range);
35803590 };
0 commit comments