diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 027ee21527d22..c6158d6efa505 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7258,12 +7258,30 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop);
     SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(),
                                                  ChainOps.end());
+    auto IsZExtOrSExt = [](const unsigned Opcode) -> bool {
+      return Opcode == Instruction::ZExt || Opcode == Instruction::SExt;
+    };
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
+    //
+    // For ARM, some of the instructions can be folded into the reduction
+    // instruction, so we need to mark all folded instructions free.
+    // For example, we can fold reduce(mul(ext(A), ext(B))) into one
+    // instruction.
     for (auto *ChainOp : ChainOps) {
       for (Value *Op : ChainOp->operands()) {
-        if (auto *I = dyn_cast<Instruction>(Op))
+        if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (I->getOpcode() == Instruction::Mul) {
+            auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
+            auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
+            if (Ext0 && IsZExtOrSExt(Ext0->getOpcode()) && Ext1 &&
+                Ext0->getOpcode() == Ext1->getOpcode()) {
+              ChainOpsAndOperands.insert(Ext0);
+              ChainOpsAndOperands.insert(Ext1);
+            }
+          }
+        }
       }
     }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 68a62638b9d58..94b7ddefbbb96 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1571,6 +1571,10 @@ class VPWidenCastRecipe : public VPRecipeWithIRFlags {
   /// Produce widened copies of the cast.
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPWidenCastRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2948ecc580edc..3acd7d5e3ca4c 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -1462,6 +1462,55 @@ void VPWidenCastRecipe::execute(VPTransformState &State) {
   State.addMetadata(Cast, cast_or_null<Instruction>(getUnderlyingValue()));
 }
 
+InstructionCost VPWidenCastRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  // Computes the CastContextHint from a recipe that may access memory.
+  auto ComputeCCH = [&](const VPRecipeBase *R) -> TTI::CastContextHint {
+    if (VF.isScalar())
+      return TTI::CastContextHint::Normal;
+    if (isa<VPInterleaveRecipe>(R))
+      return TTI::CastContextHint::Interleave;
+    if (const auto *ReplicateRecipe = dyn_cast<VPReplicateRecipe>(R))
+      return ReplicateRecipe->isPredicated() ? TTI::CastContextHint::Masked
+                                             : TTI::CastContextHint::Normal;
+    const auto *WidenMemoryRecipe = dyn_cast<VPWidenMemoryRecipe>(R);
+    if (WidenMemoryRecipe == nullptr)
+      return TTI::CastContextHint::None;
+    if (!WidenMemoryRecipe->isConsecutive())
+      return TTI::CastContextHint::GatherScatter;
+    if (WidenMemoryRecipe->isReverse())
+      return TTI::CastContextHint::Reversed;
+    if (WidenMemoryRecipe->isMasked())
+      return TTI::CastContextHint::Masked;
+    return TTI::CastContextHint::Normal;
+  };
+
+  VPValue *Operand = getOperand(0);
+  TTI::CastContextHint CCH = TTI::CastContextHint::None;
+  // For Trunc/FPTrunc, get the context from the only user.
+  if ((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
+      !hasMoreThanOneUniqueUser() && getNumUsers() > 0) {
+    if (auto *StoreRecipe = dyn_cast<VPRecipeBase>(*user_begin()))
+      CCH = ComputeCCH(StoreRecipe);
+  }
+  // For Z/Sext, get the context from the operand.
+  else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt ||
+           Opcode == Instruction::FPExt) {
+    if (Operand->isLiveIn())
+      CCH = TTI::CastContextHint::Normal;
+    else if (Operand->getDefiningRecipe())
+      CCH = ComputeCCH(Operand->getDefiningRecipe());
+  }
+
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(Operand), VF));
+  auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+  // Arm TTI will use the underlying instruction to determine the cost.
+  return Ctx.TTI.getCastInstrCost(
+      Opcode, DestTy, SrcTy, CCH, TTI::TCK_RecipThroughput,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPWidenCastRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index f2978b0a758b6..1900182f76e07 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -135,7 +135,7 @@ class VPValue {
   }
 
   /// Returns true if the value has more than one unique user.
-  bool hasMoreThanOneUniqueUser() {
+  bool hasMoreThanOneUniqueUser() const {
     if (getNumUsers() == 0)
       return false;
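
Note: the reduce(mul(ext(A), ext(B))) chain referenced in the LoopVectorize comment above typically comes from a widening dot-product style source loop. A minimal C++ sketch of such a loop follows; the function name and types are illustrative and not part of the patch.

#include <cstdint>

// Hypothetical example only: sign-extended i8 multiplies accumulated into an
// i32 reduction. When vectorized, the two sexts and the mul feed the reduction
// directly, and per the comment in precomputeCosts the whole chain can be
// folded into a single reduction instruction on ARM, so costing the extends
// separately would overestimate the loop's cost.
int32_t dot_i8(const int8_t *A, const int8_t *B, int N) {
  int32_t Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += static_cast<int32_t>(A[I]) * static_cast<int32_t>(B[I]);
  return Sum;
}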