Skip to content

Commit 37127f7

Browse files
authored
[LV] Bundle sub reductions into VPExpressionRecipe (#147255)
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account. Stacked PRs: 1. #147026 2. -> #147255 3. #147302 4. #147513
1 parent de6a832 commit 37127f7

File tree

14 files changed

+542
-32
lines changed

14 files changed

+542
-32
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,12 +1647,12 @@ class TargetTransformInfo {
16471647
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
16481648

16491649
/// Calculate the cost of an extended reduction pattern, similar to
1650-
/// getArithmeticReductionCost of an Add reduction with multiply and optional
1651-
/// extensions. This is the cost of as:
1652-
/// ResTy vecreduce.add(mul (A, B)).
1653-
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
1650+
/// getArithmeticReductionCost of an Add/Sub reduction with multiply and
1651+
/// optional extensions. This is the cost of:
1652+
/// * ResTy vecreduce.add/sub(mul (A, B)) or,
1653+
/// * ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))).
16541654
LLVM_ABI InstructionCost getMulAccReductionCost(
1655-
bool IsUnsigned, Type *ResTy, VectorType *Ty,
1655+
bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
16561656
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
16571657

16581658
/// Calculate the cost of an extended reduction pattern, similar to

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -971,8 +971,8 @@ class TargetTransformInfoImplBase {
971971
}
972972

973973
virtual InstructionCost
974-
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
975-
TTI::TargetCostKind CostKind) const {
974+
getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
975+
VectorType *Ty, TTI::TargetCostKind CostKind) const {
976976
return 1;
977977
}
978978

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,14 +3260,17 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
32603260
}
32613261

32623262
InstructionCost
3263-
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
3263+
getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
3264+
VectorType *Ty,
32643265
TTI::TargetCostKind CostKind) const override {
32653266
// Without any native support, this is equivalent to the cost of
32663267
// vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))) or
32673268
// vecreduce.add/sub(mul(A, B)).
3269+
assert((RedOpcode == Instruction::Add || RedOpcode == Instruction::Sub) &&
3270+
"The reduction opcode is expected to be Add or Sub.");
32683271
VectorType *ExtTy = VectorType::get(ResTy, Ty);
32693272
InstructionCost RedCost = thisT()->getArithmeticReductionCost(
3270-
Instruction::Add, ExtTy, std::nullopt, CostKind);
3273+
RedOpcode, ExtTy, std::nullopt, CostKind);
32713274
InstructionCost ExtCost = thisT()->getCastInstrCost(
32723275
IsUnsigned ? Instruction::ZExt : Instruction::SExt, ExtTy, Ty,
32733276
TTI::CastContextHint::None, CostKind);

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,9 +1283,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
12831283
}
12841284

12851285
InstructionCost TargetTransformInfo::getMulAccReductionCost(
1286-
bool IsUnsigned, Type *ResTy, VectorType *Ty,
1286+
bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
12871287
TTI::TargetCostKind CostKind) const {
1288-
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
1288+
return TTIImpl->getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, Ty,
1289+
CostKind);
12891290
}
12901291

12911292
InstructionCost

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5486,13 +5486,14 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
54865486
}
54875487

54885488
InstructionCost
5489-
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
5490-
VectorType *VecTy,
5489+
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
5490+
Type *ResTy, VectorType *VecTy,
54915491
TTI::TargetCostKind CostKind) const {
54925492
EVT VecVT = TLI->getValueType(DL, VecTy);
54935493
EVT ResVT = TLI->getValueType(DL, ResTy);
54945494

5495-
if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) {
5495+
if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple() &&
5496+
RedOpcode == Instruction::Add) {
54965497
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
54975498

54985499
// The legal cases with dotprod are
@@ -5503,7 +5504,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
55035504
return LT.first + 2;
55045505
}
55055506

5506-
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
5507+
return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, VecTy,
5508+
CostKind);
55075509
}
55085510

55095511
InstructionCost

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
460460
TTI::TargetCostKind CostKind) const override;
461461

462462
InstructionCost getMulAccReductionCost(
463-
bool IsUnsigned, Type *ResTy, VectorType *Ty,
463+
bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
464464
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
465465

466466
InstructionCost

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,9 +1916,11 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
19161916
}
19171917

19181918
InstructionCost
1919-
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
1920-
VectorType *ValTy,
1919+
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
1920+
Type *ResTy, VectorType *ValTy,
19211921
TTI::TargetCostKind CostKind) const {
1922+
if (RedOpcode != Instruction::Add)
1923+
return InstructionCost::getInvalid(CostKind);
19221924
EVT ValVT = TLI->getValueType(DL, ValTy);
19231925
EVT ResVT = TLI->getValueType(DL, ResTy);
19241926

@@ -1939,7 +1941,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
19391941
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
19401942
}
19411943

1942-
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
1944+
return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, ValTy,
1945+
CostKind);
19431946
}
19441947

19451948
InstructionCost

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
299299
VectorType *ValTy, std::optional<FastMathFlags> FMF,
300300
TTI::TargetCostKind CostKind) const override;
301301
InstructionCost
302-
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
302+
getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
303+
VectorType *ValTy,
303304
TTI::TargetCostKind CostKind) const override;
304305

305306
InstructionCost

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5426,7 +5426,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
54265426
TTI::CastContextHint::None, CostKind, RedOp);
54275427

54285428
InstructionCost RedCost = TTI.getMulAccReductionCost(
5429-
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
5429+
IsUnsigned, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), ExtType,
5430+
CostKind);
54305431

54315432
if (RedCost.isValid() &&
54325433
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5471,7 +5472,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
54715472
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
54725473

54735474
InstructionCost RedCost = TTI.getMulAccReductionCost(
5474-
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
5475+
IsUnsigned, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), ExtType,
5476+
CostKind);
54755477
InstructionCost ExtraExtCost = 0;
54765478
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
54775479
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5490,7 +5492,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
54905492
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
54915493

54925494
InstructionCost RedCost = TTI.getMulAccReductionCost(
5493-
true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
5495+
true, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), VectorTy,
5496+
CostKind);
54945497

54955498
if (RedCost.isValid() && RedCost < MulCost + BaseCost)
54965499
return I == RetI ? RedCost : 0;

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2810,24 +2810,25 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
28102810
toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF));
28112811
assert(RedTy->isIntegerTy() &&
28122812
"VPExpressionRecipe only supports integer types currently.");
2813+
unsigned Opcode = RecurrenceDescriptor::getOpcode(
2814+
cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind());
28132815
switch (ExpressionType) {
28142816
case ExpressionTypes::ExtendedReduction: {
2815-
unsigned Opcode = RecurrenceDescriptor::getOpcode(
2816-
cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
28172817
return Ctx.TTI.getExtendedReductionCost(
28182818
Opcode,
28192819
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
28202820
Instruction::ZExt,
28212821
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
28222822
}
28232823
case ExpressionTypes::MulAccReduction:
2824-
return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
2824+
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
2825+
Ctx.CostKind);
28252826

28262827
case ExpressionTypes::ExtMulAccReduction:
28272828
return Ctx.TTI.getMulAccReductionCost(
28282829
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
28292830
Instruction::ZExt,
2830-
RedTy, SrcVecTy, Ctx.CostKind);
2831+
Opcode, RedTy, SrcVecTy, Ctx.CostKind);
28312832
}
28322833
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
28332834
}

0 commit comments

Comments
 (0)