@@ -299,7 +299,9 @@ InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) {
299299 UI = &WidenMem->getIngredient ();
300300
301301 InstructionCost RecipeCost;
302- if (UI && Ctx.skipCostComputation (UI, VF.isVector ())) {
302+ if ((UI && Ctx.skipCostComputation (UI, VF.isVector ())) ||
303+ (Ctx.FoldedRecipes .contains (VF) &&
304+ Ctx.FoldedRecipes .at (VF).contains (this ))) {
303305 RecipeCost = 0 ;
304306 } else {
305307 RecipeCost = computeCost (VF, Ctx);
@@ -2188,30 +2190,143 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
21882190 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
21892191 unsigned Opcode = RdxDesc.getOpcode ();
21902192
2191- // TODO: Support any-of and in-loop reductions.
2193+ // TODO: Support any-of reductions.
21922194 assert (
21932195 (!RecurrenceDescriptor::isAnyOfRecurrenceKind (RdxKind) ||
21942196 ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
21952197 " Any-of reduction not implemented in VPlan-based cost model currently." );
2196- assert (
2197- (!cast<VPReductionPHIRecipe>(getOperand (0 ))->isInLoop () ||
2198- ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
2199- " In-loop reduction not implemented in VPlan-based cost model currently." );
22002198
22012199 assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
22022200 " Inferred type and recurrence type mismatch." );
22032201
2204- // Cost = Reduction cost + BinOp cost
2205- InstructionCost Cost =
2202+ // BaseCost = Reduction cost + BinOp cost
2203+ InstructionCost BaseCost =
22062204 Ctx.TTI .getArithmeticInstrCost (Opcode, ElementTy, CostKind);
22072205 if (RecurrenceDescriptor::isMinMaxRecurrenceKind (RdxKind)) {
22082206 Intrinsic::ID Id = getMinMaxReductionIntrinsicOp (RdxKind);
2209- return Cost + Ctx.TTI .getMinMaxReductionCost (
2210- Id, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2207+ BaseCost += Ctx.TTI .getMinMaxReductionCost (
2208+ Id, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2209+ } else {
2210+ BaseCost += Ctx.TTI .getArithmeticReductionCost (
2211+ Opcode, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
22112212 }
22122213
2213- return Cost + Ctx.TTI .getArithmeticReductionCost (
2214- Opcode, VectorTy, RdxDesc.getFastMathFlags (), CostKind);
2214+ using namespace llvm ::VPlanPatternMatch;
2215+ auto GetMulAccReductionCost =
2216+ [&](const VPReductionRecipe *Red) -> InstructionCost {
2217+ VPValue *A, *B;
2218+ InstructionCost InnerExt0Cost = 0 ;
2219+ InstructionCost InnerExt1Cost = 0 ;
2220+ InstructionCost ExtCost = 0 ;
2221+ InstructionCost MulCost = 0 ;
2222+
2223+ VectorType *SrcVecTy = VectorTy;
2224+ Type *InnerExt0Ty;
2225+ Type *InnerExt1Ty;
2226+ Type *MaxInnerExtTy;
2227+ bool IsUnsigned = true ;
2228+ bool HasOuterExt = false ;
2229+
2230+ auto *Ext = dyn_cast_if_present<VPWidenCastRecipe>(
2231+ Red->getVecOp ()->getDefiningRecipe ());
2232+ VPRecipeBase *Mul;
2233+ // Try to match outer extend reduce.add(ext(...))
2234+ if (Ext && match (Ext, m_ZExtOrSExt (m_VPValue ())) &&
2235+ cast<VPWidenCastRecipe>(Ext)->getNumUsers () == 1 ) {
2236+ IsUnsigned =
2237+ Ext->getOpcode () == Instruction::CastOps::ZExt ? true : false ;
2238+ ExtCost = Ext->computeCost (VF, Ctx);
2239+ Mul = Ext->getOperand (0 )->getDefiningRecipe ();
2240+ HasOuterExt = true ;
2241+ } else {
2242+ Mul = Red->getVecOp ()->getDefiningRecipe ();
2243+ }
2244+
2245+ // Match reduce.add(mul())
2246+ if (Mul && match (Mul, m_Mul (m_VPValue (A), m_VPValue (B))) &&
2247+ cast<VPWidenRecipe>(Mul)->getNumUsers () == 1 ) {
2248+ MulCost = cast<VPWidenRecipe>(Mul)->computeCost (VF, Ctx);
2249+ auto *InnerExt0 =
2250+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
2251+ auto *InnerExt1 =
2252+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
2253+ bool HasInnerExt = false ;
2254+ // Try to match inner extends.
2255+ if (InnerExt0 && InnerExt1 &&
2256+ match (InnerExt0, m_ZExtOrSExt (m_VPValue ())) &&
2257+ match (InnerExt1, m_ZExtOrSExt (m_VPValue ())) &&
2258+ InnerExt0->getOpcode () == InnerExt1->getOpcode () &&
2259+ (InnerExt0->getNumUsers () > 0 &&
2260+ !InnerExt0->hasMoreThanOneUniqueUser ()) &&
2261+ (InnerExt1->getNumUsers () > 0 &&
2262+ !InnerExt1->hasMoreThanOneUniqueUser ())) {
2263+ InnerExt0Cost = InnerExt0->computeCost (VF, Ctx);
2264+ InnerExt1Cost = InnerExt1->computeCost (VF, Ctx);
2265+ Type *InnerExt0Ty = Ctx.Types .inferScalarType (InnerExt0->getOperand (0 ));
2266+ Type *InnerExt1Ty = Ctx.Types .inferScalarType (InnerExt1->getOperand (0 ));
2267+ Type *MaxInnerExtTy = InnerExt0Ty->getIntegerBitWidth () >
2268+ InnerExt1Ty->getIntegerBitWidth ()
2269+ ? InnerExt0Ty
2270+ : InnerExt1Ty;
2271+ SrcVecTy = cast<VectorType>(ToVectorTy (MaxInnerExtTy, VF));
2272+ IsUnsigned = true ;
2273+ HasInnerExt = true ;
2274+ }
2275+ InstructionCost MulAccRedCost = Ctx.TTI .getMulAccReductionCost (
2276+ IsUnsigned, ElementTy, SrcVecTy, CostKind);
2277+ // Check if folding ext/mul into MulAccReduction is profitable.
2278+ if (MulAccRedCost.isValid () &&
2279+ MulAccRedCost <
2280+ ExtCost + MulCost + InnerExt0Cost + InnerExt1Cost + BaseCost) {
2281+ if (HasInnerExt) {
2282+ Ctx.FoldedRecipes [VF].insert (InnerExt0);
2283+ Ctx.FoldedRecipes [VF].insert (InnerExt1);
2284+ }
2285+ Ctx.FoldedRecipes [VF].insert (Mul);
2286+ if (HasOuterExt)
2287+ Ctx.FoldedRecipes [VF].insert (Ext);
2288+ return MulAccRedCost;
2289+ }
2290+ }
2291+ return InstructionCost::getInvalid ();
2292+ };
2293+
2294+ // Match reduce(ext(...))
2295+ auto GetExtendedReductionCost =
2296+ [&](const VPReductionRecipe *Red) -> InstructionCost {
2297+ VPValue *VecOp = Red->getVecOp ();
2298+ VPValue *A;
2299+ if (match (VecOp, m_ZExtOrSExt (m_VPValue (A))) && VecOp->getNumUsers () == 1 ) {
2300+ VPWidenCastRecipe *Ext =
2301+ cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe ());
2302+ bool IsUnsigned = Ext->getOpcode () == Instruction::CastOps::ZExt;
2303+ InstructionCost ExtCost = Ext->computeCost (VF, Ctx);
2304+ auto *ExtVecTy =
2305+ cast<VectorType>(ToVectorTy (Ctx.Types .inferScalarType (A), VF));
2306+ InstructionCost ExtendedRedCost = Ctx.TTI .getExtendedReductionCost (
2307+ Opcode, IsUnsigned, ElementTy, ExtVecTy, RdxDesc.getFastMathFlags (),
2308+ CostKind);
2309+ // Check if folding ext into ExtendedReduction is profitable.
2310+ if (ExtendedRedCost.isValid () && ExtendedRedCost < ExtCost + BaseCost) {
2311+ Ctx.FoldedRecipes [VF].insert (Ext);
2312+ return ExtendedRedCost;
2313+ }
2314+ }
2315+ return InstructionCost::getInvalid ();
2316+ };
2317+
2318+ // Match MulAccReduction patterns.
2319+ InstructionCost MulAccCost = GetMulAccReductionCost (this );
2320+ if (MulAccCost.isValid ())
2321+ return MulAccCost;
2322+
2323+ // Match ExtendedReduction patterns.
2324+ InstructionCost ExtendedCost = GetExtendedReductionCost (this );
2325+ if (ExtendedCost.isValid ())
2326+ return ExtendedCost;
2327+
2328+ // Default cost.
2329+ return BaseCost;
22152330}
22162331
22172332#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
0 commit comments