@@ -270,7 +270,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
270270 UI = &WidenMem->getIngredient ();
271271
272272 InstructionCost RecipeCost;
273- if (UI && Ctx.skipCostComputation (UI, VF.isVector ())) {
273+ if ((UI && Ctx.skipCostComputation (UI, VF.isVector ())) ||
274+ (Ctx.FoldedRecipes .contains (VF) &&
275+ Ctx.FoldedRecipes .at (VF).contains (this ))) {
274276 RecipeCost = 0 ;
275277 } else {
276278 RecipeCost = computeCost (VF, Ctx);
@@ -2187,30 +2189,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
21872189 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21882190 unsigned Opcode = RdxDesc.getOpcode ();
21892191
2190- // TODO: Support any-of and in-loop reductions.
2192+ // TODO: Support any-of reductions.
21912193 assert (
21922194 (!RecurrenceDescriptor::isAnyOfRecurrenceKind (RdxKind) ||
21932195 ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
21942196 " Any-of reduction not implemented in VPlan-based cost model currently." );
2195- assert (
2196- (!cast<VPReductionPHIRecipe>(getOperand (0 ))->isInLoop () ||
2197- ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
2198- " In-loop reduction not implemented in VPlan-based cost model currently." );
21992197
22002198 assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
22012199 " Inferred type and recurrence type mismatch." );
22022200
2203- // Cost = Reduction cost + BinOp cost
2204- InstructionCost Cost =
2201+ // BaseCost = Reduction cost + BinOp cost
2202+ InstructionCost BaseCost =
22052203 Ctx.TTI .getArithmeticInstrCost (Opcode, ElementTy, CostKind);
22062204 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RdxKind)) {
22072205 Intrinsic::ID Id = getMinMaxReductionIntrinsicOp (RdxKind);
2208- return Cost + Ctx.TTI .getMinMaxReductionCost (
2209- Id, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2206+ BaseCost += Ctx.TTI .getMinMaxReductionCost (
2207+ Id, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2208+ } else {
2209+ BaseCost += Ctx.TTI .getArithmeticReductionCost (
2210+ Opcode, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
22102211 }
22112212
2212- return Cost + Ctx.TTI .getArithmeticReductionCost (
2213- Opcode, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2213+ using namespace llvm ::VPlanPatternMatch;
2214+ auto GetMulAccReductionCost =
2215+ [&](const VPReductionRecipe *Red) -> InstructionCost {
2216+ VPValue *A, *B;
2217+ InstructionCost InnerExt0Cost = 0 ;
2218+ InstructionCost InnerExt1Cost = 0 ;
2219+ InstructionCost ExtCost = 0 ;
2220+ InstructionCost MulCost = 0 ;
2221+
2222+ VectorType *SrcVecTy = VectorTy;
2223+ Type *InnerExt0Ty;
2224+ Type *InnerExt1Ty;
2225+ Type *MaxInnerExtTy;
2226+ bool IsUnsigned = true ;
2227+ bool HasOuterExt = false ;
2228+
2229+ auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
2230+ Red->getVecOp ()->getDefiningRecipe ());
2231+ VPRecipeBase *Mul;
2232+ // Try to match outer extend reduce.add(ext(...))
2233+ if (Ext && match (Ext, m_ZExtOrSExt (m_VPValue ())) &&
2234+ cast<VPWidenCastRecipe>(Ext)->getNumUsers () == 1 ) {
2235+ IsUnsigned =
2236+ Ext->getOpcode () == Instruction::CastOps::ZExt ? true : false ;
2237+ ExtCost = Ext->computeCost (VF, Ctx);
2238+ Mul = Ext->getOperand (0 )->getDefiningRecipe ();
2239+ HasOuterExt = true ;
2240+ } else {
2241+ Mul = Red->getVecOp ()->getDefiningRecipe ();
2242+ }
2243+
2244+ // Match reduce.add(mul())
2245+ if (Mul && match (Mul, m_Mul (m_VPValue (A), m_VPValue (B))) &&
2246+ cast<VPWidenRecipe>(Mul)->getNumUsers () == 1 ) {
2247+ MulCost = cast<VPWidenRecipe>(Mul)->computeCost (VF, Ctx);
2248+ auto *InnerExt0 =
2249+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
2250+ auto *InnerExt1 =
2251+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
2252+ bool HasInnerExt = false ;
2253+ // Try to match inner extends.
2254+ if (InnerExt0 && InnerExt1 &&
2255+ match (InnerExt0, m_ZExtOrSExt (m_VPValue ())) &&
2256+ match (InnerExt1, m_ZExtOrSExt (m_VPValue ())) &&
2257+ InnerExt0->getOpcode () == InnerExt1->getOpcode () &&
2258+ (InnerExt0->getNumUsers () > 0 &&
2259+ !InnerExt0->hasMoreThanOneUniqueUser ()) &&
2260+ (InnerExt1->getNumUsers () > 0 &&
2261+ !InnerExt1->hasMoreThanOneUniqueUser ())) {
2262+ InnerExt0Cost = InnerExt0->computeCost (VF, Ctx);
2263+ InnerExt1Cost = InnerExt1->computeCost (VF, Ctx);
2264+ Type *InnerExt0Ty = Ctx.Types .inferScalarType (InnerExt0->getOperand (0 ));
2265+ Type *InnerExt1Ty = Ctx.Types .inferScalarType (InnerExt1->getOperand (0 ));
2266+ Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth () >
2267+ InnerExt1Ty->getIntegerBitWidth ()
2268+ ? InnerExt0Ty
2269+ : InnerExt1Ty;
2270+ SrcVecTy = cast<VectorType>(ToVectorTy (MaxInnerExtTy, VF));
2271+ IsUnsigned = true ;
2272+ HasInnerExt = true ;
2273+ }
2274+ InstructionCost MulAccRedCost = Ctx.TTI .getMulAccReductionCost (
2275+ IsUnsigned, ElementTy, SrcVecTy, CostKind);
2276+ // Check if folding ext/mul into MulAccReduction is profitable.
2277+ if (MulAccRedCost.isValid () &&
2278+ MulAccRedCost <
2279+ ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
2280+ if (HasInnerExt) {
2281+ Ctx.FoldedRecipes [VF].insert (InnerExt0);
2282+ Ctx.FoldedRecipes [VF].insert (InnerExt1);
2283+ }
2284+ Ctx.FoldedRecipes [VF].insert (Mul);
2285+ if (HasOuterExt)
2286+ Ctx.FoldedRecipes [VF].insert (Ext);
2287+ return MulAccRedCost;
2288+ }
2289+ }
2290+ return InstructionCost::getInvalid ();
2291+ };
2292+
2293+ // Match reduce(ext(...))
2294+ auto GetExtendedReductionCost =
2295+ [&](const VPReductionRecipe *Red) -> InstructionCost {
2296+ VPValue *VecOp = Red->getVecOp ();
2297+ VPValue *A;
2298+ if (match (VecOp, m_ZExtOrSExt (m_VPValue (A))) && VecOp->getNumUsers () == 1 ) {
2299+ VPWidenCastRecipe *Ext =
2300+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe ());
2301+ bool IsUnsigned = Ext->getOpcode () == Instruction::CastOps::ZExt;
2302+ InstructionCost ExtCost = Ext->computeCost (VF, Ctx);
2303+ auto *ExtVecTy =
2304+ cast<VectorType>(ToVectorTy (Ctx.Types .inferScalarType (A), VF));
2305+ InstructionCost ExtendedRedCost = Ctx.TTI .getExtendedReductionCost (
2306+ Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags (),
2307+ CostKind);
2308+ // Check if folding ext into ExtendedReduction is profitable.
2309+ if (ExtendedRedCost.isValid () && ExtendedRedCost < ExtCost + BaseCost) {
2310+ Ctx.FoldedRecipes [VF].insert (Ext);
2311+ return ExtendedRedCost;
2312+ }
2313+ }
2314+ return InstructionCost::getInvalid ();
2315+ };
2316+
2317+ // Match MulAccReduction patterns.
2318+ InstructionCost MulAccCost = GetMulAccReductionCost (this );
2319+ if (MulAccCost.isValid ())
2320+ return MulAccCost;
2321+
2322+ // Match ExtendedReduction patterns.
2323+ InstructionCost ExtendedCost = GetExtendedReductionCost (this );
2324+ if (ExtendedCost.isValid ())
2325+ return ExtendedCost;
2326+
2327+ // Default cost.
2328+ return BaseCost;
22142329}
22152330
22162331#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments