@@ -3883,6 +3883,7 @@ class BoUpSLP {
38833883 enum CombinedOpcode {
38843884 NotCombinedOp = -1,
38853885 MinMax = Instruction::OtherOpsEnd + 1,
3886+ FMulAdd,
38863887 };
38873888 CombinedOpcode CombinedOp = NotCombinedOp;
38883889
@@ -4033,6 +4034,9 @@ class BoUpSLP {
40334034 /// Returns true if any scalar in the list is a copyable element.
40344035 bool hasCopyableElements() const { return !CopyableElements.empty(); }
40354036
4037+ /// Returns the state of the operations.
4038+ const InstructionsState &getOperations() const { return S; }
4039+
40364040 /// When ReuseReorderShuffleIndices is empty it just returns position of \p
40374041 /// V within vector of Scalars. Otherwise, try to remap on its reuse index.
40384042 unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,82 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
1198711991 }
1198811992}
1198911993
11994+ static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
11995+ const InstructionsState &S,
11996+ DominatorTree &DT, const DataLayout &DL,
11997+ TargetTransformInfo &TTI,
11998+ const TargetLibraryInfo &TLI) {
11999+ assert(all_of(VL,
12000+ [](Value *V) {
12001+ return V->getType()->getScalarType()->isFloatingPointTy();
12002+ }) &&
12003+ "Can only convert to FMA for floating point types");
12004+ assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
12005+
12006+ auto CheckForContractable = [&](ArrayRef<Value *> VL) {
12007+ FastMathFlags FMF;
12008+ FMF.set();
12009+ for (Value *V : VL) {
12010+ auto *I = dyn_cast<Instruction>(V);
12011+ if (!I)
12012+ continue;
12013+ // TODO: support for copyable elements.
12014+ Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
12015+ if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
12016+ continue;
12017+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12018+ FMF &= FPCI->getFastMathFlags();
12019+ }
12020+ return FMF.allowContract();
12021+ };
12022+ if (!CheckForContractable(VL))
12023+ return InstructionCost::getInvalid();
12024+ // fmul also should be contractable
12025+ InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
12026+ SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
12027+
12028+ InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
12029+ if (!OpS.valid())
12030+ return InstructionCost::getInvalid();
12031+ if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
12032+ return InstructionCost::getInvalid();
12033+ if (!CheckForContractable(Operands.front()))
12034+ return InstructionCost::getInvalid();
12035+ // Compare the costs.
12036+ InstructionCost FMulPlusFAddCost = 0;
12037+ InstructionCost FMACost = 0;
12038+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
12039+ FastMathFlags FMF;
12040+ FMF.set();
12041+ for (Value *V : VL) {
12042+ auto *I = dyn_cast<Instruction>(V);
12043+ if (!I)
12044+ continue;
12045+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12046+ FMF &= FPCI->getFastMathFlags();
12047+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
12048+ }
12049+ unsigned NumOps = 0;
12050+ for (auto [V, Op] : zip(VL, Operands.front())) {
12051+ auto *I = dyn_cast<Instruction>(Op);
12052+ if (!I || !I->hasOneUse()) {
12053+ if (auto *OpI = dyn_cast<Instruction>(V))
12054+ FMACost += TTI.getInstructionCost(OpI, CostKind);
12055+ if (I)
12056+ FMACost += TTI.getInstructionCost(I, CostKind);
12057+ continue;
12058+ }
12059+ ++NumOps;
12060+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12061+ FMF &= FPCI->getFastMathFlags();
12062+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
12063+ }
12064+ Type *Ty = VL.front()->getType();
12065+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
12066+ FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
12067+ return FMACost < FMulPlusFAddCost ? FMACost : InstructionCost::getInvalid();
12068+ }
12069+
1199012070void BoUpSLP::transformNodes() {
1199112071 constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1199212072 BaseGraphSize = VectorizableTree.size();
@@ -12355,6 +12435,25 @@ void BoUpSLP::transformNodes() {
1235512435 }
1235612436 break;
1235712437 }
12438+ case Instruction::FSub:
12439+ case Instruction::FAdd: {
12440+ // Check if possible to convert (a*b)+c to fma.
12441+ if (E.State != TreeEntry::Vectorize ||
12442+ !E.getOperations().isAddSubLikeOp())
12443+ break;
12444+ if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
12445+ .isValid())
12446+ break;
12447+ // This node is a fmuladd node.
12448+ E.CombinedOp = TreeEntry::FMulAdd;
12449+ TreeEntry *FMulEntry = getOperandEntry(&E, 0);
12450+ if (FMulEntry->UserTreeIndex &&
12451+ FMulEntry->State == TreeEntry::Vectorize) {
12452+ // The FMul node is part of the combined fmuladd node.
12453+ FMulEntry->State = TreeEntry::CombinedVectorize;
12454+ }
12455+ break;
12456+ }
1235812457 default:
1235912458 break;
1236012459 }
@@ -13587,6 +13686,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1358713686 }
1358813687 return IntrinsicCost;
1358913688 };
13689+ auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
13690+ Instruction *VI) {
13691+ InstructionCost Cost = canConvertToFMA(VI, S, *DT, *DL, TTI, *TLI);
13692+ return Cost;
13693+ };
1359013694 switch (ShuffleOrOp) {
1359113695 case Instruction::PHI: {
1359213696 // Count reused scalars.
@@ -13927,6 +14031,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1392714031 };
1392814032 return GetCostDiff(GetScalarCost, GetVectorCost);
1392914033 }
14034+ case TreeEntry::FMulAdd: {
14035+ auto GetScalarCost = [&](unsigned Idx) {
14036+ if (isa<PoisonValue>(UniqueValues[Idx]))
14037+ return InstructionCost(TTI::TCC_Free);
14038+ return GetFMulAddCost(E->getOperations(),
14039+ cast<Instruction>(UniqueValues[Idx]));
14040+ };
14041+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
14042+ FastMathFlags FMF;
14043+ FMF.set();
14044+ for (Value *V : E->Scalars) {
14045+ if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
14046+ FMF &= FPCI->getFastMathFlags();
14047+ if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
14048+ FMF &= FPCIOp->getFastMathFlags();
14049+ }
14050+ }
14051+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
14052+ {VecTy, VecTy, VecTy}, FMF);
14053+ InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
14054+ return VecCost + CommonCost;
14055+ };
14056+ return GetCostDiff(GetScalarCost, GetVectorCost);
14057+ }
1393014058 case Instruction::FNeg:
1393114059 case Instruction::Add:
1393214060 case Instruction::FAdd:
@@ -13964,8 +14092,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1396414092 }
1396514093 TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
1396614094 TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
13967- return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
13968- Op1Info, Op2Info, Operands);
14095+ InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
14096+ ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
14097+ if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
14098+ I && (ShuffleOrOp == Instruction::FAdd ||
14099+ ShuffleOrOp == Instruction::FSub)) {
14100+ InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
14101+ if (IntrinsicCost.isValid())
14102+ ScalarCost = IntrinsicCost;
14103+ }
14104+ return ScalarCost;
1396914105 };
1397014106 auto GetVectorCost = [=](InstructionCost CommonCost) {
1397114107 if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -22594,11 +22730,21 @@ class HorizontalReduction {
2259422730 /// Try to find a reduction tree.
2259522731 bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
2259622732 ScalarEvolution &SE, const DataLayout &DL,
22597- const TargetLibraryInfo &TLI) {
22733+ const TargetLibraryInfo &TLI,
22734+ DominatorTree &DT, TargetTransformInfo &TTI) {
2259822735 RdxKind = HorizontalReduction::getRdxKind(Root);
2259922736 if (!isVectorizable(RdxKind, Root))
2260022737 return false;
2260122738
22739+ // FMA reduction root - skip.
22740+ auto CheckForFMA = [&](Instruction *I) {
22741+ return RdxKind == RecurKind::FAdd &&
22742+ canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
22743+ .isValid();
22744+ };
22745+ if (CheckForFMA(Root))
22746+ return false;
22747+
2260222748 // Analyze "regular" integer/FP types for reductions - no target-specific
2260322749 // types or pointers.
2260422750 Type *Ty = Root->getType();
@@ -22636,7 +22782,7 @@ class HorizontalReduction {
2263622782 // Also, do not try to reduce const values, if the operation is not
2263722783 // foldable.
2263822784 if (!EdgeInst || Level > RecursionMaxDepth ||
22639- getRdxKind(EdgeInst) != RdxKind ||
22785+ getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
2264022786 IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
2264122787 !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
2264222788 !isVectorizable(RdxKind, EdgeInst) ||
@@ -24205,13 +24351,13 @@ bool SLPVectorizerPass::vectorizeHorReduction(
2420524351 Stack.emplace(SelectRoot(), 0);
2420624352 SmallPtrSet<Value *, 8> VisitedInstrs;
2420724353 bool Res = false;
24208- auto && TryToReduce = [this, &R](Instruction *Inst) -> Value * {
24354+ auto TryToReduce = [this, &R, TTI = TTI ](Instruction *Inst) -> Value * {
2420924355 if (R.isAnalyzedReductionRoot(Inst))
2421024356 return nullptr;
2421124357 if (!isReductionCandidate(Inst))
2421224358 return nullptr;
2421324359 HorizontalReduction HorRdx;
24214- if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
24360+ if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI ))
2421524361 return nullptr;
2421624362 return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
2421724363 };
@@ -24277,6 +24423,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
2427724423
2427824424 if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
2427924425 return false;
24426+ // Skip potential FMA candidates.
24427+ if ((I->getOpcode() == Instruction::FAdd ||
24428+ I->getOpcode() == Instruction::FSub) &&
24429+ canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
24430+ .isValid())
24431+ return false;
2428024432
2428124433 Value *P = I->getParent();
2428224434
0 commit comments