@@ -3883,6 +3883,7 @@ class BoUpSLP {
3883
3883
/// Tags for tree entries that represent a combined operation instead of a
/// single IR opcode. The real tags are numbered from
/// Instruction::OtherOpsEnd + 1 upwards (presumably so they cannot collide
/// with regular instruction opcodes - TODO confirm).
enum CombinedOpcode {
3884
3884
// Value used to initialize CombinedOp: the entry is not a combined operation.
NotCombinedOp = -1,
3885
3885
MinMax = Instruction::OtherOpsEnd + 1,
3886
// Set on an FAdd/FSub entry that transformNodes() decided to treat as an
// fmuladd node together with the FMul entry feeding its operand 0.
+ FMulAdd,
3886
3887
};
3887
3888
CombinedOpcode CombinedOp = NotCombinedOp;
3888
3889
@@ -4033,6 +4034,9 @@ class BoUpSLP {
4033
4034
/// Returns true if at least one copyable element is recorded for this entry,
/// i.e. CopyableElements is not empty.
4034
4035
bool hasCopyableElements() const { return !CopyableElements.empty(); }
4035
4036
4037
+ /// Returns, by const reference, the InstructionsState \c S held by this entry.
4038
+ const InstructionsState &getOperations() const { return S; }
4039
+
4036
4040
/// When ReuseReorderShuffleIndices is empty it just returns position of \p
4037
4041
/// V within vector of Scalars. Otherwise, try to remap on its reuse index.
4038
4042
unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,81 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
11987
11991
}
11988
11992
}
11989
11993
11994
+ static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
11995
+ const InstructionsState &S,
11996
+ DominatorTree &DT, const DataLayout &DL,
11997
+ TargetTransformInfo &TTI,
11998
+ const TargetLibraryInfo &TLI) {
11999
+ assert(all_of(VL,
12000
+ [](Value *V) {
12001
+ return V->getType()->getScalarType()->isFloatingPointTy();
12002
+ }) &&
12003
+ "Can only convert to FMA for floating point types");
12004
+ assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
12005
+
12006
+ auto CheckForContractable = [&](ArrayRef<Value *> VL) {
12007
+ FastMathFlags FMF;
12008
+ FMF.set();
12009
+ for (Value *V : VL) {
12010
+ auto *I = dyn_cast<Instruction>(V);
12011
+ if (!I)
12012
+ continue;
12013
+ // TODO: support for copyable elements.
12014
+ Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
12015
+ if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
12016
+ continue;
12017
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12018
+ FMF &= FPCI->getFastMathFlags();
12019
+ }
12020
+ return FMF.allowContract();
12021
+ };
12022
+ if (!CheckForContractable(VL))
12023
+ return InstructionCost::getInvalid();
12024
+ // fmul also should be contractable
12025
+ InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
12026
+ SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
12027
+
12028
+ InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
12029
+ if (!OpS.valid())
12030
+ return InstructionCost::getInvalid();
12031
+ if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
12032
+ return InstructionCost::getInvalid();
12033
+ if (!CheckForContractable(Operands.front()))
12034
+ return InstructionCost::getInvalid();
12035
+ // Compare the costs.
12036
+ InstructionCost FMulPlusFAddCost = 0;
12037
+ InstructionCost FMACost = 0;
12038
+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
12039
+ FastMathFlags FMF;
12040
+ FMF.set();
12041
+ for (Value *V : VL) {
12042
+ auto *I = dyn_cast<Instruction>(V);
12043
+ if (!I)
12044
+ continue;
12045
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12046
+ FMF &= FPCI->getFastMathFlags();
12047
+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
12048
+ }
12049
+ unsigned NumOps = 0;
12050
+ for (auto [V, Op] : zip(VL, Operands.front())) {
12051
+ auto *I = dyn_cast<Instruction>(Op);
12052
+ if (!I || !I->hasOneUse()) {
12053
+ FMACost += TTI.getInstructionCost(cast<Instruction>(V), CostKind);
12054
+ if (I)
12055
+ FMACost += TTI.getInstructionCost(I, CostKind);
12056
+ continue;
12057
+ }
12058
+ ++NumOps;
12059
+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12060
+ FMF &= FPCI->getFastMathFlags();
12061
+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
12062
+ }
12063
+ Type *Ty = VL.front()->getType();
12064
+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
12065
+ FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
12066
+ return FMACost < FMulPlusFAddCost ? FMACost : InstructionCost::getInvalid();
12067
+ }
12068
+
11990
12069
void BoUpSLP::transformNodes() {
11991
12070
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
11992
12071
BaseGraphSize = VectorizableTree.size();
@@ -12355,6 +12434,25 @@ void BoUpSLP::transformNodes() {
12355
12434
}
12356
12435
break;
12357
12436
}
12437
+ case Instruction::FSub:
12438
+ case Instruction::FAdd: {
12439
+ // Check if possible to convert (a*b)+c to fma.
12440
+ if (E.State != TreeEntry::Vectorize ||
12441
+ !E.getOperations().isAddSubLikeOp())
12442
+ break;
12443
+ if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
12444
+ .isValid())
12445
+ break;
12446
+ // This node is a fmuladd node.
12447
+ E.CombinedOp = TreeEntry::FMulAdd;
12448
+ TreeEntry *FMulEntry = getOperandEntry(&E, 0);
12449
+ if (FMulEntry->UserTreeIndex &&
12450
+ FMulEntry->State == TreeEntry::Vectorize) {
12451
+ // The FMul node is part of the combined fmuladd node.
12452
+ FMulEntry->State = TreeEntry::CombinedVectorize;
12453
+ }
12454
+ break;
12455
+ }
12358
12456
default:
12359
12457
break;
12360
12458
}
@@ -13587,6 +13685,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
13587
13685
}
13588
13686
return IntrinsicCost;
13589
13687
};
13688
+ auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
13689
+ Instruction *VI) {
13690
+ InstructionCost Cost = canConvertToFMA(VI, S, *DT, *DL, TTI, *TLI);
13691
+ return Cost;
13692
+ };
13590
13693
switch (ShuffleOrOp) {
13591
13694
case Instruction::PHI: {
13592
13695
// Count reused scalars.
@@ -13927,6 +14030,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
13927
14030
};
13928
14031
return GetCostDiff(GetScalarCost, GetVectorCost);
13929
14032
}
14033
+ case TreeEntry::FMulAdd: {
14034
+ auto GetScalarCost = [&](unsigned Idx) {
14035
+ if (isa<PoisonValue>(UniqueValues[Idx]))
14036
+ return InstructionCost(TTI::TCC_Free);
14037
+ return GetFMulAddCost(E->getOperations(),
14038
+ cast<Instruction>(UniqueValues[Idx]));
14039
+ };
14040
+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
14041
+ FastMathFlags FMF;
14042
+ FMF.set();
14043
+ for (Value *V : E->Scalars) {
14044
+ if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
14045
+ FMF &= FPCI->getFastMathFlags();
14046
+ if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
14047
+ FMF &= FPCIOp->getFastMathFlags();
14048
+ }
14049
+ }
14050
+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
14051
+ {VecTy, VecTy, VecTy}, FMF);
14052
+ InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
14053
+ return VecCost + CommonCost;
14054
+ };
14055
+ return GetCostDiff(GetScalarCost, GetVectorCost);
14056
+ }
13930
14057
case Instruction::FNeg:
13931
14058
case Instruction::Add:
13932
14059
case Instruction::FAdd:
@@ -13964,8 +14091,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
13964
14091
}
13965
14092
TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
13966
14093
TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
13967
- return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
13968
- Op1Info, Op2Info, Operands);
14094
+ InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
14095
+ ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
14096
+ if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
14097
+ I && (ShuffleOrOp == Instruction::FAdd ||
14098
+ ShuffleOrOp == Instruction::FSub)) {
14099
+ InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
14100
+ if (IntrinsicCost.isValid())
14101
+ ScalarCost = IntrinsicCost;
14102
+ }
14103
+ return ScalarCost;
13969
14104
};
13970
14105
auto GetVectorCost = [=](InstructionCost CommonCost) {
13971
14106
if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -22594,11 +22729,21 @@ class HorizontalReduction {
22594
22729
/// Try to find a reduction tree.
22595
22730
bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
22596
22731
ScalarEvolution &SE, const DataLayout &DL,
22597
- const TargetLibraryInfo &TLI) {
22732
+ const TargetLibraryInfo &TLI,
22733
+ DominatorTree &DT, TargetTransformInfo &TTI) {
22598
22734
RdxKind = HorizontalReduction::getRdxKind(Root);
22599
22735
if (!isVectorizable(RdxKind, Root))
22600
22736
return false;
22601
22737
22738
+ // FMA reduction root - skip.
22739
+ auto CheckForFMA = [&](Instruction *I) {
22740
+ return RdxKind == RecurKind::FAdd &&
22741
+ canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
22742
+ .isValid();
22743
+ };
22744
+ if (CheckForFMA(Root))
22745
+ return false;
22746
+
22602
22747
// Analyze "regular" integer/FP types for reductions - no target-specific
22603
22748
// types or pointers.
22604
22749
Type *Ty = Root->getType();
@@ -22636,7 +22781,7 @@ class HorizontalReduction {
22636
22781
// Also, do not try to reduce const values, if the operation is not
22637
22782
// foldable.
22638
22783
if (!EdgeInst || Level > RecursionMaxDepth ||
22639
- getRdxKind(EdgeInst) != RdxKind ||
22784
+ getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
22640
22785
IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
22641
22786
!hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
22642
22787
!isVectorizable(RdxKind, EdgeInst) ||
@@ -24205,13 +24350,13 @@ bool SLPVectorizerPass::vectorizeHorReduction(
24205
24350
Stack.emplace(SelectRoot(), 0);
24206
24351
SmallPtrSet<Value *, 8> VisitedInstrs;
24207
24352
bool Res = false;
24208
- auto && TryToReduce = [this, &R](Instruction *Inst) -> Value * {
24353
+ auto TryToReduce = [this, &R, TTI = TTI ](Instruction *Inst) -> Value * {
24209
24354
if (R.isAnalyzedReductionRoot(Inst))
24210
24355
return nullptr;
24211
24356
if (!isReductionCandidate(Inst))
24212
24357
return nullptr;
24213
24358
HorizontalReduction HorRdx;
24214
- if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
24359
+ if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI ))
24215
24360
return nullptr;
24216
24361
return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
24217
24362
};
@@ -24277,6 +24422,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
24277
24422
24278
24423
if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
24279
24424
return false;
24425
+ // Skip potential FMA candidates.
24426
+ if ((I->getOpcode() == Instruction::FAdd ||
24427
+ I->getOpcode() == Instruction::FSub) &&
24428
+ canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
24429
+ .isValid())
24430
+ return false;
24280
24431
24281
24432
Value *P = I->getParent();
24282
24433
0 commit comments