Skip to content

Commit d7d017f

Browse files
committed
Use assertion in isMulAccValidAndClampRange
1 parent 2b6bdb6 commit d7d017f

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 45 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -3472,25 +3472,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
34723472
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
34733473
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
34743474

3475-
InstructionCost ExtRedCost;
34763475
InstructionCost ExtCost =
34773476
cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
34783477
InstructionCost RedCost = Red->computeCost(VF, Ctx);
3478+
InstructionCost BaseCost = ExtCost + RedCost;
34793479

34803480
if (isa<VPPartialReductionRecipe>(Red)) {
34813481
TargetTransformInfo::PartialReductionExtendKind ExtKind =
34823482
TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
34833483
// FIXME: Move partial reduction creation, costing and clamping
34843484
// here from LoopVectorize.cpp.
3485-
ExtRedCost = Ctx.TTI.getPartialReductionCost(
3486-
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
3487-
llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
3488-
} else {
3489-
ExtRedCost = Ctx.TTI.getExtendedReductionCost(
3490-
Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3491-
Red->getFastMathFlags(), CostKind);
3485+
InstructionCost PartialReductionCost =
3486+
Ctx.TTI.getPartialReductionCost(
3487+
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
3488+
llvm::TargetTransformInfo::PR_None, std::nullopt,
3489+
Ctx.CostKind);
3490+
assert(PartialReductionCost <= BaseCost &&
3491+
"A partial reduction should have a lower cost than the "
3492+
"extend + add");
3493+
return true;
34923494
}
3493-
return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
3495+
3496+
InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
3497+
Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3498+
Red->getFastMathFlags(), CostKind);
3499+
return ExtRedCost.isValid() && ExtRedCost < BaseCost;
34943500
},
34953501
Range);
34963502
};
@@ -3535,46 +3541,50 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35353541
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35363542
Type *SrcTy =
35373543
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
3538-
InstructionCost MulAccCost;
3544+
3545+
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
3546+
InstructionCost RedCost = Red->computeCost(VF, Ctx);
3547+
InstructionCost ExtCost = 0;
3548+
if (Ext0)
3549+
ExtCost += Ext0->computeCost(VF, Ctx);
3550+
if (Ext1)
3551+
ExtCost += Ext1->computeCost(VF, Ctx);
3552+
if (OuterExt)
3553+
ExtCost += OuterExt->computeCost(VF, Ctx);
3554+
InstructionCost BaseCost = ExtCost + MulCost + RedCost;
35393555

35403556
if (IsPartialReduction) {
35413557
Type *SrcTy2 =
35423558
Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
35433559
// FIXME: Move partial reduction creation, costing and clamping
35443560
// here from LoopVectorize.cpp.
3545-
MulAccCost = Ctx.TTI.getPartialReductionCost(
3546-
Opcode, SrcTy, SrcTy2, RedTy, VF,
3547-
Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
3548-
Ext0->getOpcode())
3549-
: TargetTransformInfo::PR_None,
3550-
Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
3551-
Ext1->getOpcode())
3552-
: TargetTransformInfo::PR_None,
3553-
Mul->getOpcode(), CostKind);
3554-
} else {
3561+
InstructionCost PartialReductionCost =
3562+
Ctx.TTI.getPartialReductionCost(
3563+
Opcode, SrcTy, SrcTy2, RedTy, VF,
3564+
Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
3565+
Ext0->getOpcode())
3566+
: TargetTransformInfo::PR_None,
3567+
Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
3568+
Ext1->getOpcode())
3569+
: TargetTransformInfo::PR_None,
3570+
Mul->getOpcode(), CostKind);
3571+
assert(PartialReductionCost <= BaseCost &&
3572+
"A partial reduction should have a lower cost than the "
3573+
"extend + mul + add");
3574+
return true;
3575+
}
35553576
// Only partial reductions support mixed extends at the moment.
35563577
if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
35573578
return false;
35583579

35593580
bool IsZExt =
35603581
!Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
35613582
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
3562-
MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
3563-
SrcVecTy, CostKind);
3564-
}
3565-
3566-
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
3567-
InstructionCost RedCost = Red->computeCost(VF, Ctx);
3568-
InstructionCost ExtCost = 0;
3569-
if (Ext0)
3570-
ExtCost += Ext0->computeCost(VF, Ctx);
3571-
if (Ext1)
3572-
ExtCost += Ext1->computeCost(VF, Ctx);
3573-
if (OuterExt)
3574-
ExtCost += OuterExt->computeCost(VF, Ctx);
3583+
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
3584+
IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
35753585

3576-
return MulAccCost.isValid() &&
3577-
MulAccCost < ExtCost + MulCost + RedCost;
3586+
return MulAccCost.isValid() &&
3587+
MulAccCost < ExtCost + MulCost + RedCost;
35783588
},
35793589
Range);
35803590
};

0 commit comments

Comments
 (0)