Skip to content

Commit 324659a

Browse files
committed
[VPlan] Implement VPReductionRecipe::computeCost(). NFC
Implementation of `computeCost()` function for `VPReductionRecipe`.
1 parent cd12ffb commit 324659a

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,6 +2476,10 @@ class VPReductionRecipe : public VPSingleDefRecipe {
24762476
/// Generate the reduction in the loop
24772477
void execute(VPTransformState &State) override;
24782478

2479+
/// Return the cost of this VPReductionRecipe for vectorization factor \p VF.
/// Target costs are queried through \p Ctx. The result is the cost of the
/// reduction's binary operation plus the cost of the vector reduction
/// (min/max or arithmetic, depending on the recurrence kind). An invalid
/// cost is returned when no vector type can be formed for \p VF.
2480+
InstructionCost computeCost(ElementCount VF,
2481+
VPCostContext &Ctx) const override;
2482+
24792483
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
24802484
/// Print the recipe.
24812485
void print(raw_ostream &O, const Twine &Indent,

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,30 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
20712071
State.set(this, NewRed, /*IsScalar*/ true);
20722072
}
20732073

2074+
/// Estimate the cost of this reduction recipe for vectorization factor \p VF.
///
/// The estimate is the sum of two target queries made through \p Ctx.TTI:
/// the cost of the reduction's binary operation on the recurrence's element
/// type, and the cost of reducing a vector of that element type (a min/max
/// reduction for min/max recurrence kinds, an arithmetic reduction
/// otherwise). All queries use the reciprocal-throughput cost kind.
///
/// \returns an invalid cost when widening the element type by \p VF does not
/// produce a vector type.
InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                               VPCostContext &Ctx) const {
  Type *ScalarTy = RdxDesc.getRecurrenceType();
  auto *WideTy = dyn_cast<VectorType>(ToVectorTy(ScalarTy, VF));
  // Without a vector type there is nothing to reduce; report the cost as
  // invalid rather than guessing.
  if (!WideTy)
    return InstructionCost::getInvalid();

  const TTI::TargetCostKind Kind = TTI::TCK_RecipThroughput;
  const RecurKind Recurrence = RdxDesc.getRecurrenceKind();
  const unsigned BinOpcode = RdxDesc.getOpcode();
  const auto FMF = RdxDesc.getFastMathFlags();

  // First term: the binary operation that accompanies the reduction.
  InstructionCost BinOpCost =
      Ctx.TTI.getArithmeticInstrCost(BinOpcode, ScalarTy, Kind);

  // Second term: the reduction itself. Min/max recurrences are costed via
  // their reduction intrinsic, everything else via the arithmetic opcode.
  InstructionCost ReductionCost;
  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Recurrence)) {
    Intrinsic::ID MinMaxId = getMinMaxReductionIntrinsicOp(Recurrence);
    ReductionCost = Ctx.TTI.getMinMaxReductionCost(MinMaxId, WideTy, FMF, Kind);
  } else {
    ReductionCost =
        Ctx.TTI.getArithmeticReductionCost(BinOpcode, WideTy, FMF, Kind);
  }

  return BinOpCost + ReductionCost;
}
2097+
20742098
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
20752099
void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
20762100
VPSlotTracker &SlotTracker) const {

0 commit comments

Comments
 (0)