@@ -2022,6 +2022,11 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }

+static bool isZExtOrSExt(Instruction::CastOps CastOpcode) {
+  return CastOpcode == Instruction::CastOps::ZExt ||
+         CastOpcode == Instruction::CastOps::SExt;
+}
+
 InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
                                                VPCostContext &Ctx) const {
   RecurKind RdxKind = RdxDesc.getRecurrenceKind();
@@ -2030,17 +2035,149 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();

-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  InstructionCost BaseCost;
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost = Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost = Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // For a call to the llvm.fmuladd intrinsic we need to add the cost of a
+  // normal fmul instruction to the cost of the fadd reduction.
+  if (RdxKind == RecurKind::FMulAdd)
+    BaseCost +=
+        Ctx.TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind);
+
+  // If we're using ordered reductions then we can just return the base cost
+  // here, since getArithmeticReductionCost calculates the full ordered
+  // reduction cost when FP reassociation is not allowed.
+  if (IsOrdered && Opcode == Instruction::FAdd)
+    return BaseCost;
+
+  // Special case for Arm from D93476: the reduction instruction can be
+  // substituted under the following conditions.
+  //
+  // %sa = sext <16 x i8> A to <16 x i32>
+  // %sb = sext <16 x i8> B to <16 x i32>
+  // %m = mul <16 x i32> %sa, %sb
+  // %r = vecreduce.add(%m)
+  // ->
+  // R = VMLADAV A, B
+  //
+  // There are other instructions for performing add reductions of
+  // v4i32/v8i16/v16i8 into i32 (VADDV), for doing the same with v4i32->i64
+  // (VADDLV) and for performing a v4i32/v8i16 MLA into an i64 (VMLALDAV).
+  //
+  // We are looking for one of the following patterns and returning the
+  // minimal acceptable cost:
+  //   reduce.add(ext(mul(ext(A), ext(B)))) or
+  //   reduce(ext(A)) or
+  //   reduce.add(mul(ext(A), ext(B))) or
+  //   reduce.add(mul(A, B)) or
+  //   reduce(A).
+
+  // Try to match reduce(ext(...))
+  auto *Ext = dyn_cast<VPWidenCastRecipe>(getVecOp());
+  if (Ext && isZExtOrSExt(Ext->getOpcode())) {
+    bool isUnsigned = Ext->getOpcode() == Instruction::CastOps::ZExt;
+
+    // Try to match reduce.add(ext(mul(...)))
+    auto *ExtTy = cast<VectorType>(
+        ToVectorTy(Ext->getOperand(0)->getUnderlyingValue()->getType(), VF));
+    auto *Mul = dyn_cast_if_present<VPWidenRecipe>(
+        Ext->getOperand(0)->getDefiningRecipe());
+    if (Mul && Mul->getOpcode() == Instruction::Mul &&
+        Opcode == Instruction::Add) {
+      auto *MulTy = cast<VectorType>(
+          ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+      auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+      auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+
+      // Match reduce.add(ext(mul(ext(A), ext(B))))
+      if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+          isZExtOrSExt(InnerExt1->getOpcode()) &&
+          InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+        Type *InnerExt0Ty =
+            InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+        Type *InnerExt1Ty =
+            InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+        // Get the largest type.
+        auto *MaxExtVecTy = cast<VectorType>(
+            ToVectorTy(InnerExt0Ty->getIntegerBitWidth() >
+                               InnerExt1Ty->getIntegerBitWidth()
+                           ? InnerExt0Ty
+                           : InnerExt1Ty,
+                       VF));
+        InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+            isUnsigned, ElementTy, MaxExtVecTy, CostKind);
+        InstructionCost InnerExtCost =
+            Ctx.TTI.getCastInstrCost(InnerExt0->getOpcode(), MulTy, MaxExtVecTy,
+                                     TTI::CastContextHint::None, CostKind);
+        InstructionCost MulCost =
+            Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+        InstructionCost ExtCost =
+            Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                     TTI::CastContextHint::None, CostKind);
+        if (RedCost.isValid() &&
+            RedCost < InnerExtCost * 2 + MulCost + ExtCost + BaseCost)
+          return RedCost;
+      }
+    }
+
+    // Match reduce(ext(A))
+    InstructionCost RedCost =
+        Ctx.TTI.getExtendedReductionCost(Opcode, isUnsigned, ElementTy, ExtTy,
+                                         RdxDesc.getFastMathFlags(), CostKind);
+    InstructionCost ExtCost =
+        Ctx.TTI.getCastInstrCost(Ext->getOpcode(), VectorTy, ExtTy,
+                                 TTI::CastContextHint::None, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
+      return RedCost;
+  }
+
+  // Try to match reduce.add(mul(...))
+  auto *Mul =
+      dyn_cast_if_present<VPWidenRecipe>(getVecOp()->getDefiningRecipe());
+  if (Mul && Mul->getOpcode() == Instruction::Mul &&
+      Opcode == Instruction::Add) {
+    // Match reduce.add(mul(ext(A), ext(B)))
+    auto *InnerExt0 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(0));
+    auto *InnerExt1 = dyn_cast<VPWidenCastRecipe>(Mul->getOperand(1));
+    auto *MulTy =
+        cast<VectorType>(ToVectorTy(Mul->getUnderlyingValue()->getType(), VF));
+    InstructionCost MulCost =
+        Ctx.TTI.getArithmeticInstrCost(Instruction::Mul, MulTy, CostKind);
+    if (InnerExt0 && isZExtOrSExt(InnerExt0->getOpcode()) && InnerExt1 &&
+        InnerExt0->getOpcode() == InnerExt1->getOpcode()) {
+      Type *InnerExt0Ty =
+          InnerExt0->getOperand(0)->getUnderlyingValue()->getType();
+      Type *InnerExt1Ty =
+          InnerExt1->getOperand(0)->getUnderlyingValue()->getType();
+      auto *MaxInnerExtVecTy = cast<VectorType>(ToVectorTy(
+          InnerExt0Ty->getIntegerBitWidth() > InnerExt1Ty->getIntegerBitWidth()
+              ? InnerExt0Ty
+              : InnerExt1Ty,
+          VF));
+      bool isUnsigned = InnerExt0->getOpcode() == Instruction::CastOps::ZExt;
+      InstructionCost RedCost = Ctx.TTI.getMulAccReductionCost(
+          isUnsigned, ElementTy, MaxInnerExtVecTy, CostKind);
+      InstructionCost InnerExtCost = Ctx.TTI.getCastInstrCost(
+          InnerExt0->getOpcode(), MulTy, MaxInnerExtVecTy,
+          TTI::CastContextHint::None, CostKind);
+      if (RedCost.isValid() && RedCost < BaseCost + MulCost + 2 * InnerExtCost)
+        return RedCost;
+    }
+    // Match reduce.add(mul)
+    InstructionCost RedCost =
+        Ctx.TTI.getMulAccReductionCost(true, ElementTy, VectorTy, CostKind);
+    if (RedCost.isValid() && RedCost < BaseCost + MulCost)
+      return RedCost;
   }

-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Normal cost = Reduction cost + BinOp cost
+  return BaseCost + Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
 }

 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
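
A note on the cost comparisons this patch adds: each matched pattern only wins if the target reports a valid fused reduction cost that undercuts the sum of the per-recipe costs it would replace (for the full chain, `InnerExtCost * 2 + MulCost + ExtCost + BaseCost`). The standalone C++ sketch below illustrates that decision in miniature; it is not the LLVM API, and all names and cost numbers are invented for illustration, assuming a target that may or may not provide a VMLADAV-like instruction.

```cpp
#include <cstdio>
#include <optional>

struct Costs {
  int Base;     // plain vecreduce.add of the wide vector
  int Ext;      // outer extend
  int InnerExt; // one inner extend (two operands, so charged twice)
  int Mul;      // widened multiply
};

// Cost to use for the whole reduce.add(ext(mul(ext(A), ext(B)))) chain:
// the fused mul-acc reduction cost if the target reports one and it beats
// the naive sum of the component costs, otherwise the naive sum.
int chainCost(const Costs &C, std::optional<int> MulAccRedCost) {
  int Naive = C.Base + C.Ext + C.Mul + 2 * C.InnerExt;
  if (MulAccRedCost && *MulAccRedCost < Naive)
    return *MulAccRedCost;
  return Naive;
}

int main() {
  Costs C{/*Base=*/2, /*Ext=*/1, /*InnerExt=*/1, /*Mul=*/1};
  // Target with a VMLADAV-like instruction: the whole chain costs 1.
  std::printf("fused:   %d\n", chainCost(C, 1));            // prints 1
  // Target without one: fall back to the summed per-recipe costs.
  std::printf("unfused: %d\n", chainCost(C, std::nullopt)); // prints 6
}
```

In the patch itself the fused figure comes from TTI's getMulAccReductionCost (or getExtendedReductionCost for the extend-only pattern) and every comparison is guarded by RedCost.isValid(), so targets without such instructions simply fall through to the normal reduction-plus-binop costing.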