@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
270270 UI = &WidenMem->getIngredient ();
271271
272272 InstructionCost RecipeCost;
273- if (UI && Ctx.skipCostComputation (UI, VF.isVector ())) {
273+ if ((UI && Ctx.skipCostComputation (UI, VF.isVector ())) ||
274+ (Ctx.FoldedRecipes .contains (VF) &&
275+ Ctx.FoldedRecipes .at (VF).contains (this ))) {
274276 RecipeCost = 0 ;
275277 } else {
276278 RecipeCost = computeCost (VF, Ctx);
@@ -2185,30 +2187,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
21852187 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21862188 unsigned Opcode = RdxDesc.getOpcode ();
21872189
2188- // TODO: Support any-of and in-loop reductions.
2190+ // TODO: Support any-of reductions.
21892191 assert (
21902192 (!RecurrenceDescriptor::isAnyOfRecurrenceKind (RdxKind) ||
21912193 ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
21922194 " Any-of reduction not implemented in VPlan-based cost model currently." );
2193- assert (
2194- (!cast<VPReductionPHIRecipe>(getOperand (0 ))->isInLoop () ||
2195- ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
2196- " In-loop reduction not implemented in VPlan-based cost model currently." );
21972195
21982196 assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
21992197 " Inferred type and recurrence type mismatch." );
22002198
2201- // Cost = Reduction cost + BinOp cost
2202- InstructionCost Cost =
2199+ // BaseCost = Reduction cost + BinOp cost
2200+ InstructionCost BaseCost =
22032201 Ctx.TTI .getArithmeticInstrCost (Opcode, ElementTy, CostKind);
22042202 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RdxKind)) {
22052203 Intrinsic::ID Id = getMinMaxReductionIntrinsicOp (RdxKind);
2206- return Cost + Ctx.TTI .getMinMaxReductionCost (
2207- Id, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2204+ BaseCost += Ctx.TTI .getMinMaxReductionCost (
2205+ Id, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2206+ } else {
2207+ BaseCost += Ctx.TTI .getArithmeticReductionCost (
2208+ Opcode, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
22082209 }
22092210
2210- return Cost + Ctx.TTI .getArithmeticReductionCost (
2211- Opcode, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2211+ using namespace llvm ::VPlanPatternMatch;
2212+ auto GetMulAccReductionCost =
2213+ [&](const VPReductionRecipe *Red) -> InstructionCost {
2214+ VPValue *A, *B;
2215+ InstructionCost InnerExt0Cost = 0 ;
2216+ InstructionCost InnerExt1Cost = 0 ;
2217+ InstructionCost ExtCost = 0 ;
2218+ InstructionCost MulCost = 0 ;
2219+
2220+ VectorType *SrcVecTy = VectorTy;
2221+ Type *InnerExt0Ty;
2222+ Type *InnerExt1Ty;
2223+ Type *MaxInnerExtTy;
2224+ bool IsUnsigned = true ;
2225+ bool HasOuterExt = false ;
2226+
2227+ auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
2228+ Red->getVecOp ()->getDefiningRecipe ());
2229+ VPRecipeBase *Mul;
2230+ // Try to match outer extend reduce.add(ext(...))
2231+ if (Ext && match (Ext, m_ZExtOrSExt (m_VPValue ())) &&
2232+ cast<VPWidenCastRecipe>(Ext)->getNumUsers () == 1 ) {
2233+ IsUnsigned =
2234+ Ext->getOpcode () == Instruction::CastOps::ZExt ? true : false ;
2235+ ExtCost = Ext->computeCost (VF, Ctx);
2236+ Mul = Ext->getOperand (0 )->getDefiningRecipe ();
2237+ HasOuterExt = true ;
2238+ } else {
2239+ Mul = Red->getVecOp ()->getDefiningRecipe ();
2240+ }
2241+
2242+ // Match reduce.add(mul())
2243+ if (Mul && match (Mul, m_Mul (m_VPValue (A), m_VPValue (B))) &&
2244+ cast<VPWidenRecipe>(Mul)->getNumUsers () == 1 ) {
2245+ MulCost = cast<VPWidenRecipe>(Mul)->computeCost (VF, Ctx);
2246+ auto *InnerExt0 =
2247+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
2248+ auto *InnerExt1 =
2249+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
2250+ bool HasInnerExt = false ;
2251+ // Try to match inner extends.
2252+ if (InnerExt0 && InnerExt1 &&
2253+ match (InnerExt0, m_ZExtOrSExt (m_VPValue ())) &&
2254+ match (InnerExt1, m_ZExtOrSExt (m_VPValue ())) &&
2255+ InnerExt0->getOpcode () == InnerExt1->getOpcode () &&
2256+ (InnerExt0->getNumUsers () > 0 &&
2257+ !InnerExt0->hasMoreThanOneUniqueUser ()) &&
2258+ (InnerExt1->getNumUsers () > 0 &&
2259+ !InnerExt1->hasMoreThanOneUniqueUser ())) {
2260+ InnerExt0Cost = InnerExt0->computeCost (VF, Ctx);
2261+ InnerExt1Cost = InnerExt1->computeCost (VF, Ctx);
2262+ Type *InnerExt0Ty = Ctx.Types .inferScalarType (InnerExt0->getOperand (0 ));
2263+ Type *InnerExt1Ty = Ctx.Types .inferScalarType (InnerExt1->getOperand (0 ));
2264+ Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth () >
2265+ InnerExt1Ty->getIntegerBitWidth ()
2266+ ? InnerExt0Ty
2267+ : InnerExt1Ty;
2268+ SrcVecTy = cast<VectorType>(ToVectorTy (MaxInnerExtTy, VF));
2269+ IsUnsigned = true ;
2270+ HasInnerExt = true ;
2271+ }
2272+ InstructionCost MulAccRedCost = Ctx.TTI .getMulAccReductionCost (
2273+ IsUnsigned, ElementTy, SrcVecTy, CostKind);
2274+ // Check if folding ext/mul into MulAccReduction is profitable.
2275+ if (MulAccRedCost.isValid () &&
2276+ MulAccRedCost <
2277+ ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
2278+ if (HasInnerExt) {
2279+ Ctx.FoldedRecipes [VF].insert (InnerExt0);
2280+ Ctx.FoldedRecipes [VF].insert (InnerExt1);
2281+ }
2282+ Ctx.FoldedRecipes [VF].insert (Mul);
2283+ if (HasOuterExt)
2284+ Ctx.FoldedRecipes [VF].insert (Ext);
2285+ return MulAccRedCost;
2286+ }
2287+ }
2288+ return InstructionCost::getInvalid ();
2289+ };
2290+
2291+ // Match reduce(ext(...))
2292+ auto GetExtendedReductionCost =
2293+ [&](const VPReductionRecipe *Red) -> InstructionCost {
2294+ VPValue *VecOp = Red->getVecOp ();
2295+ VPValue *A;
2296+ if (match (VecOp, m_ZExtOrSExt (m_VPValue (A))) && VecOp->getNumUsers () == 1 ) {
2297+ VPWidenCastRecipe *Ext =
2298+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe ());
2299+ bool IsUnsigned = Ext->getOpcode () == Instruction::CastOps::ZExt;
2300+ InstructionCost ExtCost = Ext->computeCost (VF, Ctx);
2301+ auto *ExtVecTy =
2302+ cast<VectorType>(ToVectorTy (Ctx.Types .inferScalarType (A), VF));
2303+ InstructionCost ExtendedRedCost = Ctx.TTI .getExtendedReductionCost (
2304+ Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags (),
2305+ CostKind);
2306+ // Check if folding ext into ExtendedReduction is profitable.
2307+ if (ExtendedRedCost.isValid () && ExtendedRedCost < ExtCost + BaseCost) {
2308+ Ctx.FoldedRecipes [VF].insert (Ext);
2309+ return ExtendedRedCost;
2310+ }
2311+ }
2312+ return InstructionCost::getInvalid ();
2313+ };
2314+
2315+ // Match MulAccReduction patterns.
2316+ InstructionCost MulAccCost = GetMulAccReductionCost (this );
2317+ if (MulAccCost.isValid ())
2318+ return MulAccCost;
2319+
2320+ // Match ExtendedReduction patterns.
2321+ InstructionCost ExtendedCost = GetExtendedReductionCost (this );
2322+ if (ExtendedCost.isValid ())
2323+ return ExtendedCost;
2324+
2325+ // Default cost.
2326+ return BaseCost;
22122327}
22132328
22142329#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments