Skip to content

Commit c4e3163

Browse files
committed
[VPlan] Factor out logic to common compute costs to helper (NFCI). (llvm#153361)
A number of recipes compute costs for the same opcodes for scalars or vectors, depending on the recipe. Move the common logic out to a helper in VPRecipeWithIRFlags, that is then used by VPReplicateRecipe, VPWidenRecipe and VPInstruction. This makes it easier to cover all relevant opcodes, without duplication. PR: llvm#153361 (cherry picked from commit 35be64a)
1 parent d13c9e2 commit c4e3163

File tree

2 files changed

+90
-68
lines changed

2 files changed

+90
-68
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,11 @@ struct VPRecipeWithIRFlags : public VPSingleDefRecipe, public VPIRFlags {
899899
}
900900

901901
void execute(VPTransformState &State) override = 0;
902+
903+
/// Compute the cost for this recipe for \p VF, using \p Opcode and \p Ctx.
904+
std::optional<InstructionCost>
905+
getCostForRecipeWithOpcode(unsigned Opcode, ElementCount VF,
906+
VPCostContext &Ctx) const;
902907
};
903908

904909
/// Helper to access the operand that contains the unroll part for this recipe

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 85 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -901,28 +901,90 @@ Value *VPInstruction::generate(VPTransformState &State) {
901901
}
902902
}
903903

904+
std::optional<InstructionCost> VPRecipeWithIRFlags::getCostForRecipeWithOpcode(
905+
unsigned Opcode, ElementCount VF, VPCostContext &Ctx) const {
906+
Type *ScalarTy = Ctx.Types.inferScalarType(this);
907+
Type *ResultTy = VF.isVector() ? toVectorTy(ScalarTy, VF) : ScalarTy;
908+
switch (Opcode) {
909+
case Instruction::FNeg:
910+
return Ctx.TTI.getArithmeticInstrCost(Opcode, ResultTy, Ctx.CostKind);
911+
case Instruction::UDiv:
912+
case Instruction::SDiv:
913+
case Instruction::SRem:
914+
case Instruction::URem:
915+
case Instruction::Add:
916+
case Instruction::FAdd:
917+
case Instruction::Sub:
918+
case Instruction::FSub:
919+
case Instruction::Mul:
920+
case Instruction::FMul:
921+
case Instruction::FDiv:
922+
case Instruction::FRem:
923+
case Instruction::Shl:
924+
case Instruction::LShr:
925+
case Instruction::AShr:
926+
case Instruction::And:
927+
case Instruction::Or:
928+
case Instruction::Xor: {
929+
TargetTransformInfo::OperandValueInfo RHSInfo = {
930+
TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
931+
932+
if (VF.isVector()) {
933+
// Certain instructions can be cheaper to vectorize if they have a
934+
// constant second vector operand. One example of this are shifts on x86.
935+
VPValue *RHS = getOperand(1);
936+
RHSInfo = Ctx.getOperandInfo(RHS);
937+
938+
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
939+
getOperand(1)->isDefinedOutsideLoopRegions())
940+
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
941+
}
942+
943+
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
944+
SmallVector<const Value *, 4> Operands;
945+
if (CtxI)
946+
Operands.append(CtxI->value_op_begin(), CtxI->value_op_end());
947+
return Ctx.TTI.getArithmeticInstrCost(
948+
Opcode, ResultTy, Ctx.CostKind,
949+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
950+
RHSInfo, Operands, CtxI, &Ctx.TLI);
951+
}
952+
case Instruction::Freeze:
953+
// This opcode is unknown. Assume that it is the same as 'mul'.
954+
return Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, ResultTy,
955+
Ctx.CostKind);
956+
case Instruction::ExtractValue:
957+
return Ctx.TTI.getInsertExtractValueCost(Instruction::ExtractValue,
958+
Ctx.CostKind);
959+
case Instruction::ICmp:
960+
case Instruction::FCmp: {
961+
Type *ScalarOpTy = Ctx.Types.inferScalarType(getOperand(0));
962+
Type *OpTy = VF.isVector() ? toVectorTy(ScalarOpTy, VF) : ScalarOpTy;
963+
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
964+
return Ctx.TTI.getCmpSelInstrCost(
965+
Opcode, OpTy, CmpInst::makeCmpResultType(OpTy), getPredicate(),
966+
Ctx.CostKind, {TTI::OK_AnyValue, TTI::OP_None},
967+
{TTI::OK_AnyValue, TTI::OP_None}, CtxI);
968+
}
969+
}
970+
return std::nullopt;
971+
}
972+
904973
InstructionCost VPInstruction::computeCost(ElementCount VF,
905974
VPCostContext &Ctx) const {
906975
if (Instruction::isBinaryOp(getOpcode())) {
907-
Type *ResTy = Ctx.Types.inferScalarType(this);
908-
if (!vputils::onlyFirstLaneUsed(this))
909-
ResTy = toVectorTy(ResTy, VF);
910-
911-
if (!getUnderlyingValue()) {
912-
switch (getOpcode()) {
913-
case Instruction::FMul:
914-
return Ctx.TTI.getArithmeticInstrCost(getOpcode(), ResTy, Ctx.CostKind);
915-
default:
916-
// TODO: Compute cost for VPInstructions without underlying values once
917-
// the legacy cost model has been retired.
918-
return 0;
919-
}
976+
if (!getUnderlyingValue() && getOpcode() != Instruction::FMul) {
977+
// TODO: Compute cost for VPInstructions without underlying values once
978+
// the legacy cost model has been retired.
979+
return 0;
920980
}
921981

922982
assert(!doesGeneratePerAllLanes() &&
923983
"Should only generate a vector value or single scalar, not scalars "
924984
"for all lanes.");
925-
return Ctx.TTI.getArithmeticInstrCost(getOpcode(), ResTy, Ctx.CostKind);
985+
return *getCostForRecipeWithOpcode(
986+
getOpcode(),
987+
vputils::onlyFirstLaneUsed(this) ? ElementCount::getFixed(1) : VF, Ctx);
926988
}
927989

928990
switch (getOpcode()) {
@@ -1963,20 +2025,13 @@ void VPWidenRecipe::execute(VPTransformState &State) {
19632025
InstructionCost VPWidenRecipe::computeCost(ElementCount VF,
19642026
VPCostContext &Ctx) const {
19652027
switch (Opcode) {
1966-
case Instruction::FNeg: {
1967-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
1968-
return Ctx.TTI.getArithmeticInstrCost(
1969-
Opcode, VectorTy, Ctx.CostKind,
1970-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
1971-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None});
1972-
}
1973-
19742028
case Instruction::UDiv:
19752029
case Instruction::SDiv:
19762030
case Instruction::SRem:
19772031
case Instruction::URem:
19782032
// More complex computation, let the legacy cost-model handle this for now.
19792033
return Ctx.getLegacyCost(cast<Instruction>(getUnderlyingValue()), VF);
2034+
case Instruction::FNeg:
19802035
case Instruction::Add:
19812036
case Instruction::FAdd:
19822037
case Instruction::Sub:
@@ -1990,45 +2045,12 @@ InstructionCost VPWidenRecipe::computeCost(ElementCount VF,
19902045
case Instruction::AShr:
19912046
case Instruction::And:
19922047
case Instruction::Or:
1993-
case Instruction::Xor: {
1994-
VPValue *RHS = getOperand(1);
1995-
// Certain instructions can be cheaper to vectorize if they have a constant
1996-
// second vector operand. One example of this are shifts on x86.
1997-
TargetTransformInfo::OperandValueInfo RHSInfo = Ctx.getOperandInfo(RHS);
1998-
1999-
if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
2000-
getOperand(1)->isDefinedOutsideLoopRegions())
2001-
RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
2002-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
2003-
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
2004-
2005-
SmallVector<const Value *, 4> Operands;
2006-
if (CtxI)
2007-
Operands.append(CtxI->value_op_begin(), CtxI->value_op_end());
2008-
return Ctx.TTI.getArithmeticInstrCost(
2009-
Opcode, VectorTy, Ctx.CostKind,
2010-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2011-
RHSInfo, Operands, CtxI, &Ctx.TLI);
2012-
}
2013-
case Instruction::Freeze: {
2014-
// This opcode is unknown. Assume that it is the same as 'mul'.
2015-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(this), VF);
2016-
return Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy,
2017-
Ctx.CostKind);
2018-
}
2019-
case Instruction::ExtractValue: {
2020-
return Ctx.TTI.getInsertExtractValueCost(Instruction::ExtractValue,
2021-
Ctx.CostKind);
2022-
}
2048+
case Instruction::Xor:
2049+
case Instruction::Freeze:
2050+
case Instruction::ExtractValue:
20232051
case Instruction::ICmp:
2024-
case Instruction::FCmp: {
2025-
Instruction *CtxI = dyn_cast_or_null<Instruction>(getUnderlyingValue());
2026-
Type *VectorTy = toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF);
2027-
return Ctx.TTI.getCmpSelInstrCost(
2028-
Opcode, VectorTy, CmpInst::makeCmpResultType(VectorTy), getPredicate(),
2029-
Ctx.CostKind, {TTI::OK_AnyValue, TTI::OP_None},
2030-
{TTI::OK_AnyValue, TTI::OP_None}, CtxI);
2031-
}
2052+
case Instruction::FCmp:
2053+
return *getCostForRecipeWithOpcode(getOpcode(), VF, Ctx);
20322054
default:
20332055
llvm_unreachable("Unsupported opcode for instruction");
20342056
}
@@ -2938,8 +2960,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29382960
// transform, avoid computing their cost multiple times for now.
29392961
Ctx.SkipCostComputation.insert(UI);
29402962

2941-
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2942-
Type *ResultTy = Ctx.Types.inferScalarType(this);
29432963
switch (UI->getOpcode()) {
29442964
case Instruction::GetElementPtr:
29452965
// We mark this instruction as zero-cost because the cost of GEPs in
@@ -2963,6 +2983,7 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29632983
SmallVector<Type *, 4> Tys;
29642984
for (VPValue *ArgOp : drop_end(operands()))
29652985
Tys.push_back(Ctx.Types.inferScalarType(ArgOp));
2986+
Type *ResultTy = Ctx.Types.inferScalarType(this);
29662987
return Ctx.TTI.getCallInstrCost(CalledFn, ResultTy, Tys, Ctx.CostKind);
29672988
}
29682989
case Instruction::Add:
@@ -2979,12 +3000,8 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
29793000
case Instruction::And:
29803001
case Instruction::Or:
29813002
case Instruction::Xor: {
2982-
auto Op2Info = Ctx.getOperandInfo(getOperand(1));
2983-
SmallVector<const Value *, 4> Operands(UI->operand_values());
2984-
return Ctx.TTI.getArithmeticInstrCost(
2985-
UI->getOpcode(), ResultTy, CostKind,
2986-
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
2987-
Op2Info, Operands, UI, &Ctx.TLI) *
3003+
return *getCostForRecipeWithOpcode(getOpcode(), ElementCount::getFixed(1),
3004+
Ctx) *
29883005
(isSingleScalar() ? 1 : VF.getFixedValue());
29893006
}
29903007
}

0 commit comments

Comments
 (0)