@@ -22730,21 +22730,11 @@ class HorizontalReduction {
   /// Try to find a reduction tree.
   bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
                                  ScalarEvolution &SE, const DataLayout &DL,
-                                 const TargetLibraryInfo &TLI,
-                                 DominatorTree &DT, TargetTransformInfo &TTI) {
+                                 const TargetLibraryInfo &TLI) {
     RdxKind = HorizontalReduction::getRdxKind(Root);
     if (!isVectorizable(RdxKind, Root))
       return false;
 
-    // FMA reduction root - skip.
-    auto CheckForFMA = [&](Instruction *I) {
-      return RdxKind == RecurKind::FAdd &&
-             canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
-                 .isValid();
-    };
-    if (CheckForFMA(Root))
-      return false;
-
     // Analyze "regular" integer/FP types for reductions - no target-specific
     // types or pointers.
     Type *Ty = Root->getType();
@@ -22782,7 +22772,7 @@ class HorizontalReduction {
         // Also, do not try to reduce const values, if the operation is not
         // foldable.
         if (!EdgeInst || Level > RecursionMaxDepth ||
-            getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
+            getRdxKind(EdgeInst) != RdxKind ||
            IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
            !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
            !isVectorizable(RdxKind, EdgeInst) ||
@@ -22901,7 +22891,8 @@ class HorizontalReduction {
 
   /// Attempt to vectorize the tree found by matchAssociativeReduction.
   Value *tryToReduce(BoUpSLP &V, const DataLayout &DL, TargetTransformInfo *TTI,
-                     const TargetLibraryInfo &TLI, AssumptionCache *AC) {
+                     const TargetLibraryInfo &TLI, AssumptionCache *AC,
+                     DominatorTree &DT) {
     constexpr unsigned RegMaxNumber = 4;
     constexpr unsigned RedValsMaxNumber = 128;
     // If there are a sufficient number of reduction values, reduce
@@ -23302,7 +23293,7 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
+          getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V, DT, DL, TLI);
      InstructionCost Cost = V.getTreeCost(VL, ReductionCost);
      LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                        << " for reduction\n");
@@ -23607,7 +23598,9 @@ class HorizontalReduction {
   InstructionCost getReductionCost(TargetTransformInfo *TTI,
                                    ArrayRef<Value *> ReducedVals,
                                    bool IsCmpSelMinMax, FastMathFlags FMF,
-                                   const BoUpSLP &R) {
+                                   const BoUpSLP &R, DominatorTree &DT,
+                                   const DataLayout &DL,
+                                   const TargetLibraryInfo &TLI) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
@@ -23632,6 +23625,22 @@ class HorizontalReduction {
         for (User *U : RdxVal->users()) {
           auto *RdxOp = cast<Instruction>(U);
           if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
+            if (RdxKind == RecurKind::FAdd) {
+              InstructionCost FMACost = canConvertToFMA(
+                  RdxOp, getSameOpcode(RdxOp, TLI), DT, DL, *TTI, TLI);
+              if (FMACost.isValid()) {
+                LLVM_DEBUG(dbgs() << "FMA cost: " << FMACost << "\n");
+                if (auto *I = dyn_cast<Instruction>(RdxVal)) {
+                  // Also, exclude scalar fmul cost.
+                  InstructionCost FMulCost =
+                      TTI->getInstructionCost(I, CostKind);
+                  LLVM_DEBUG(dbgs() << "Minus FMul cost: " << FMulCost << "\n");
+                  FMACost -= FMulCost;
+                }
+                ScalarCost += FMACost;
+                continue;
+              }
+            }
             ScalarCost += TTI->getInstructionCost(RdxOp, CostKind);
             continue;
           }
@@ -23696,8 +23705,42 @@ class HorizontalReduction {
       auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
           std::make_pair(RedTy, true));
       VectorType *RVecTy = getWidenedType(RType, ReduxWidth);
-      VectorCost +=
-          TTI->getArithmeticInstrCost(RdxOpcode, RVecTy, CostKind);
+      InstructionCost FMACost = InstructionCost::getInvalid();
+      if (RdxKind == RecurKind::FAdd) {
+        // Check if the reduction operands can be converted to FMA.
+        SmallVector<Value *> Ops;
+        FastMathFlags FMF;
+        FMF.set();
+        for (Value *RdxVal : ReducedVals) {
+          if (!RdxVal->hasOneUse()) {
+            Ops.clear();
+            break;
+          }
+          if (auto *FPCI = dyn_cast<FPMathOperator>(RdxVal))
+            FMF &= FPCI->getFastMathFlags();
+          Ops.push_back(RdxVal->user_back());
+        }
+        FMACost = canConvertToFMA(
+            Ops, getSameOpcode(Ops, TLI), DT, DL, *TTI, TLI);
+        if (FMACost.isValid()) {
+          // Calculate actual FMAD cost.
+          IntrinsicCostAttributes ICA(Intrinsic::fmuladd, RVecTy,
+                                      {RVecTy, RVecTy, RVecTy}, FMF);
+          FMACost = TTI->getIntrinsicInstrCost(ICA, CostKind);
+
+          LLVM_DEBUG(dbgs() << "Vector FMA cost: " << FMACost << "\n");
+          // Also, exclude vector fmul cost.
+          InstructionCost FMulCost = TTI->getArithmeticInstrCost(
+              Instruction::FMul, RVecTy, CostKind);
+          LLVM_DEBUG(dbgs() << "Minus vector FMul cost: " << FMulCost << "\n");
+          FMACost -= FMulCost;
+        }
+      }
+      if (FMACost.isValid())
+        VectorCost += FMACost;
+      else
+        VectorCost +=
+            TTI->getArithmeticInstrCost(RdxOpcode, RVecTy, CostKind);
       if (RType != RedTy) {
         unsigned Opcode = Instruction::Trunc;
         if (RedTy->getScalarSizeInBits() > RType->getScalarSizeInBits())
@@ -24357,9 +24400,9 @@ bool SLPVectorizerPass::vectorizeHorReduction(
     if (!isReductionCandidate(Inst))
       return nullptr;
     HorizontalReduction HorRdx;
-    if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI))
+    if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
      return nullptr;
-    return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
+    return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC, *DT);
  };
  auto TryAppendToPostponedInsts = [&](Instruction *FutureSeed) {
    if (TryOperandsAsNewSeeds && FutureSeed == Root) {
@@ -24504,7 +24547,7 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
     if (RedCost >= ScalarCost)
       return false;
 
-    return HorRdx.tryToReduce(R, *DL, &TTI, *TLI, AC) != nullptr;
+    return HorRdx.tryToReduce(R, *DL, &TTI, *TLI, AC, *DT) != nullptr;
   };
   if (Candidates.size() == 1)
     return TryToReduce(I, {Op0, Op1}) || tryToVectorizeList({Op0, Op1}, R);