@@ -3883,6 +3883,7 @@ class BoUpSLP {
3883
3883
/// Opcodes for operations that are formed by combining several scalar
/// instructions into a single vector operation. Values start past
/// Instruction::OtherOpsEnd so they can never collide with a real
/// Instruction opcode.
enum CombinedOpcode {
  NotCombinedOp = -1,
  MinMax = Instruction::OtherOpsEnd + 1,
  /// An fmul feeding an fadd/fsub, vectorized as one llvm.fmuladd
  /// intrinsic call (see canConvertToFMA for the profitability check).
  FMulAdd,
};
/// The combined operation this tree entry represents, or NotCombinedOp
/// when the entry is a regular (non-combined) operation.
CombinedOpcode CombinedOp = NotCombinedOp;
3888
3889
@@ -4033,6 +4034,9 @@ class BoUpSLP {
4033
4034
/// Returns true if any scalar in the list is a copyable element.
bool hasCopyableElements() const { return !CopyableElements.empty(); }

/// Returns the state of the operations (the InstructionsState computed
/// for this entry's scalars, including main/alternate opcode info).
const InstructionsState &getOperations() const { return S; }
4039
+
4036
4040
/// When ReuseReorderShuffleIndices is empty it just returns position of \p
4037
4041
/// V within vector of Scalars. Otherwise, try to remap on its reuse index.
4038
4042
unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,89 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
11987
11991
}
11988
11992
}
11989
11993
11994
/// Check if we can convert fadd/fsub sequence to FMAD.
/// A lane is fusable when both the fadd/fsub and the fmul feeding its
/// first operand allow contraction (fast-math "contract" flag) and the
/// fmul has a single use.
/// \param VL the scalar fadd/fsub instructions being considered.
/// \param S the InstructionsState (main/alt opcode) computed for \p VL.
/// \returns Cost of the FMAD, if conversion is possible, invalid cost otherwise.
static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
                                       const InstructionsState &S,
                                       DominatorTree &DT, const DataLayout &DL,
                                       TargetTransformInfo &TTI,
                                       const TargetLibraryInfo &TLI) {
  assert(all_of(VL,
                [](Value *V) {
                  return V->getType()->getScalarType()->isFloatingPointTy();
                }) &&
         "Can only convert to FMA for floating point types");
  assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");

  // Returns true when every instruction in \p VL that participates in the
  // node (non-copyable, matching the main or alternate opcode) permits
  // contraction. FMF starts with all flags set; intersecting (&=) can only
  // clear flags, so the result is the common subset.
  auto CheckForContractable = [&](ArrayRef<Value *> VL) {
    FastMathFlags FMF;
    FMF.set();
    for (Value *V : VL) {
      auto *I = dyn_cast<Instruction>(V);
      if (!I)
        continue;
      // Copyable elements are not real members of this operation.
      if (S.isCopyableElement(I))
        continue;
      Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
      if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
        continue;
      if (auto *FPCI = dyn_cast<FPMathOperator>(I))
        FMF &= FPCI->getFastMathFlags();
    }
    return FMF.allowContract();
  };
  // The fadd/fsub instructions themselves must be contractable.
  if (!CheckForContractable(VL))
    return InstructionCost::getInvalid();
  // fmul also should be contractable
  InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
  SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);

  // Operand 0 across all lanes must be a uniform (non-alternate) fmul.
  InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
  if (!OpS.valid())
    return InstructionCost::getInvalid();

  if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
    return InstructionCost::getInvalid();
  if (!CheckForContractable(Operands.front()))
    return InstructionCost::getInvalid();
  // Compare the costs.
  InstructionCost FMulPlusFAddCost = 0;
  InstructionCost FMACost = 0;
  constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
  FastMathFlags FMF;
  FMF.set();
  // Scalar cost of the original code: every fadd/fsub in the bundle.
  for (Value *V : VL) {
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      continue;
    if (!S.isCopyableElement(I))
      if (auto *FPCI = dyn_cast<FPMathOperator>(I))
        FMF &= FPCI->getFastMathFlags();
    FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
  }
  // Walk the fmul operands lane by lane. A lane fuses only if its operand
  // is a single-use, non-copyable fmul; otherwise the lane keeps its
  // original instructions, so charge them to the FMA side as well.
  unsigned NumOps = 0;
  for (auto [V, Op] : zip(VL, Operands.front())) {
    if (S.isCopyableElement(V))
      continue;
    auto *I = dyn_cast<Instruction>(Op);
    if (!I || !I->hasOneUse() || OpS.isCopyableElement(I)) {
      // Non-fusable lane: the fadd/fsub (and the operand instruction, if
      // any) remain as-is in the FMA variant.
      if (auto *OpI = dyn_cast<Instruction>(V))
        FMACost += TTI.getInstructionCost(OpI, CostKind);
      if (I)
        FMACost += TTI.getInstructionCost(I, CostKind);
      continue;
    }
    ++NumOps;
    if (auto *FPCI = dyn_cast<FPMathOperator>(I))
      FMF &= FPCI->getFastMathFlags();
    // The scalar baseline also pays for the fmul that would be fused.
    FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
  }
  // Charge one fmuladd intrinsic per fused lane; accept the conversion
  // only when it is strictly cheaper than the fmul+fadd baseline.
  Type *Ty = VL.front()->getType();
  IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
  FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
  return FMACost < FMulPlusFAddCost ? FMACost : InstructionCost::getInvalid();
}
12076
+
11990
12077
void BoUpSLP::transformNodes() {
11991
12078
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
11992
12079
BaseGraphSize = VectorizableTree.size();
@@ -12355,6 +12442,25 @@ void BoUpSLP::transformNodes() {
12355
12442
}
12356
12443
break;
12357
12444
}
12445
+ case Instruction::FSub:
12446
+ case Instruction::FAdd: {
12447
+ // Check if possible to convert (a*b)+c to fma.
12448
+ if (E.State != TreeEntry::Vectorize ||
12449
+ !E.getOperations().isAddSubLikeOp())
12450
+ break;
12451
+ if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
12452
+ .isValid())
12453
+ break;
12454
+ // This node is a fmuladd node.
12455
+ E.CombinedOp = TreeEntry::FMulAdd;
12456
+ TreeEntry *FMulEntry = getOperandEntry(&E, 0);
12457
+ if (FMulEntry->UserTreeIndex &&
12458
+ FMulEntry->State == TreeEntry::Vectorize) {
12459
+ // The FMul node is part of the combined fmuladd node.
12460
+ FMulEntry->State = TreeEntry::CombinedVectorize;
12461
+ }
12462
+ break;
12463
+ }
12358
12464
default:
12359
12465
break;
12360
12466
}
@@ -13587,6 +13693,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
13587
13693
}
13588
13694
return IntrinsicCost;
13589
13695
};
13696
+ auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
13697
+ Instruction *VI) {
13698
+ InstructionCost Cost = canConvertToFMA(VI, S, *DT, *DL, TTI, *TLI);
13699
+ return Cost;
13700
+ };
13590
13701
switch (ShuffleOrOp) {
13591
13702
case Instruction::PHI: {
13592
13703
// Count reused scalars.
@@ -13927,6 +14038,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
13927
14038
};
13928
14039
return GetCostDiff(GetScalarCost, GetVectorCost);
13929
14040
}
14041
+ case TreeEntry::FMulAdd: {
14042
+ auto GetScalarCost = [&](unsigned Idx) {
14043
+ if (isa<PoisonValue>(UniqueValues[Idx]))
14044
+ return InstructionCost(TTI::TCC_Free);
14045
+ return GetFMulAddCost(E->getOperations(),
14046
+ cast<Instruction>(UniqueValues[Idx]));
14047
+ };
14048
+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
14049
+ FastMathFlags FMF;
14050
+ FMF.set();
14051
+ for (Value *V : E->Scalars) {
14052
+ if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
14053
+ FMF &= FPCI->getFastMathFlags();
14054
+ if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
14055
+ FMF &= FPCIOp->getFastMathFlags();
14056
+ }
14057
+ }
14058
+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
14059
+ {VecTy, VecTy, VecTy}, FMF);
14060
+ InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
14061
+ return VecCost + CommonCost;
14062
+ };
14063
+ return GetCostDiff(GetScalarCost, GetVectorCost);
14064
+ }
13930
14065
case Instruction::FNeg:
13931
14066
case Instruction::Add:
13932
14067
case Instruction::FAdd:
@@ -13964,8 +14099,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
13964
14099
}
13965
14100
TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
13966
14101
TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
13967
- return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
13968
- Op1Info, Op2Info, Operands);
14102
+ InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
14103
+ ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
14104
+ if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
14105
+ I && (ShuffleOrOp == Instruction::FAdd ||
14106
+ ShuffleOrOp == Instruction::FSub)) {
14107
+ InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
14108
+ if (IntrinsicCost.isValid())
14109
+ ScalarCost = IntrinsicCost;
14110
+ }
14111
+ return ScalarCost;
13969
14112
};
13970
14113
auto GetVectorCost = [=](InstructionCost CommonCost) {
13971
14114
if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -24205,7 +24348,7 @@ bool SLPVectorizerPass::vectorizeHorReduction(
24205
24348
Stack.emplace(SelectRoot(), 0);
24206
24349
SmallPtrSet<Value *, 8> VisitedInstrs;
24207
24350
bool Res = false;
24208
- auto && TryToReduce = [this, &R](Instruction *Inst) -> Value * {
24351
+ auto TryToReduce = [this, &R, TTI = TTI ](Instruction *Inst) -> Value * {
24209
24352
if (R.isAnalyzedReductionRoot(Inst))
24210
24353
return nullptr;
24211
24354
if (!isReductionCandidate(Inst))
@@ -24277,6 +24420,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
24277
24420
24278
24421
if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
24279
24422
return false;
24423
+ // Skip potential FMA candidates.
24424
+ if ((I->getOpcode() == Instruction::FAdd ||
24425
+ I->getOpcode() == Instruction::FSub) &&
24426
+ canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
24427
+ .isValid())
24428
+ return false;
24280
24429
24281
24430
Value *P = I->getParent();
24282
24431
0 commit comments