@@ -7659,32 +7659,38 @@ buildIntrinsicArgTypes(const CallInst *CI, const Intrinsic::ID ID,
76597659}
76607660
76617661/// Calculates the costs of vectorized intrinsic (if possible) and vectorized
7662- /// function (if possible) calls.
7662+ /// function (if possible) calls. Returns invalid cost for the corresponding
7663+ /// calls, if they cannot be vectorized/will be scalarized.
76637664static std::pair<InstructionCost, InstructionCost>
76647665getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
76657666 TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
76667667 ArrayRef<Type *> ArgTys) {
7667- Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
7668-
7669- // Calculate the cost of the scalar and vector calls.
7670- FastMathFlags FMF;
7671- if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
7672- FMF = FPCI->getFastMathFlags();
7673- IntrinsicCostAttributes CostAttrs(ID, VecTy, ArgTys, FMF);
7674- auto IntrinsicCost =
7675- TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
7676-
76777668 auto Shape = VFShape::get(CI->getFunctionType(),
76787669 ElementCount::getFixed(VecTy->getNumElements()),
76797670 false /*HasGlobalPred*/);
76807671 Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
7681- auto LibCost = IntrinsicCost ;
7672+ auto LibCost = InstructionCost::getInvalid() ;
76827673 if (!CI->isNoBuiltin() && VecFunc) {
76837674 // Calculate the cost of the vector library call.
76847675 // If the corresponding vector call is cheaper, return its cost.
76857676 LibCost =
76867677 TTI->getCallInstrCost(nullptr, VecTy, ArgTys, TTI::TCK_RecipThroughput);
76877678 }
7679+ Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
7680+
7681+ // Calculate the cost of the vector intrinsic call.
7682+ FastMathFlags FMF;
7683+ if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
7684+ FMF = FPCI->getFastMathFlags();
7685+ const InstructionCost ScalarLimit = 10000;
7686+ IntrinsicCostAttributes CostAttrs(ID, VecTy, ArgTys, FMF, nullptr,
7687+ LibCost.isValid() ? LibCost : ScalarLimit);
7688+ auto IntrinsicCost =
7689+ TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput);
7690+ if ((LibCost.isValid() && IntrinsicCost > LibCost) ||
7691+ (!LibCost.isValid() && IntrinsicCost > ScalarLimit))
7692+ IntrinsicCost = InstructionCost::getInvalid();
7693+
76887694 return {IntrinsicCost, LibCost};
76897695}
76907696
@@ -8028,6 +8034,12 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
80288034 return TreeEntry::NeedToGather;
80298035 }
80308036 }
8037+ SmallVector<Type *> ArgTys =
8038+ buildIntrinsicArgTypes(CI, ID, VL.size(), 0, TTI);
8039+ auto *VecTy = getWidenedType(S.getMainOp()->getType(), VL.size());
8040+ auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys);
8041+ if (!VecCallCosts.first.isValid() && !VecCallCosts.second.isValid())
8042+ return TreeEntry::NeedToGather;
80318043
80328044 return TreeEntry::Vectorize;
80338045 }
0 commit comments