- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[LV] Add support for partial reductions without a binary op #133922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| ✅ With the latest revision this PR passed the C/C++ code formatter. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a couple of nits.
        
          
                llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
              
                Outdated
          
            Show resolved
            Hide resolved
        
      | std::optional<unsigned> BinOpc; | ||
| Type *ExtOpTypes[2] = {nullptr}; | ||
|  | ||
| auto collectExtInfo = [&Exts, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| auto collectExtInfo = [&Exts, | |
| auto CollectExtInfo = [&Exts, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| ret i32 %result | ||
| } | ||
|  | ||
| define i32 @zext_add_reduc_i8_i32(ptr %a) { | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should those tests go in a new file? There's no dot product in the new tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. I've also moved the non-dot-product tests from partial-reduce-dot-product.ll to a new partial-reduce.ll file.
| if (VFMinValue == Scale) | ||
| return Invalid; | ||
| } | ||
|  | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stray new line.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| (OpAExtend != OpBExtend && !ST->hasMatMulInt8() && | ||
| !ST->isSVEorStreamingSVEAvailable()))) | ||
| return Invalid; | ||
| assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) && | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation for the interface should probably also be updated, documenting that Opcode and the second the second type are optional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, absolutely right! I've tried to amend the documentation for the interface. Please take a look and see if it makes sense.
| /// accumulator). | ||
| struct PartialReductionChain { | ||
| PartialReductionChain(Instruction *Reduction, Instruction *ExtendA, | ||
| Instruction *ExtendB, Instruction *BinOp) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment above needs updating
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| @llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-llvm-transforms Author: David Sherwood (david-arm) ChangesConsider IR such as this: for.body: Conceptually we can vectorise this using partial reductions too, In order to do this I had to teach getScaledReductions that the Patch is 144.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133922.diff 9 Files Affected: 
 diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4835c66a7a3bc..5f3c8ff3bdfb4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1299,9 +1299,21 @@ class TargetTransformInfo {
   /// \return The cost of a partial reduction, which is a reduction from a
   /// vector to another vector with fewer elements of larger size. They are
   /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
-  /// takes an accumulator and a binary operation operand that itself is fed by
-  /// two extends. An example of an operation that uses a partial reduction is a
-  /// dot product, which reduces two vectors to another of 4 times fewer and 4
+  /// takes an accumulator of type \p AccumType and a second vector operand to
+  /// be accumulated, whose element count is specified by \p VF. The type of
+  /// reduction is specified by \p Opcode. The second operand passed to the
+  /// intrinsic could be the result of an extend, such as sext or zext. In
+  /// this case \p BinOp is nullopt, \p InputTypeA represents the type being
+  /// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
+  /// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
+  /// Alternatively, the second operand could be the result of a binary
+  /// operation performed on two extends, i.e.
+  ///   mul(zext i8 %a -> i32, zext i8 %b -> i32).
+  /// In this case \p BinOp may specify the opcode of the binary operation,
+  /// \p InputTypeA and \p InputTypeB the types being extended, and
+  /// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
+  /// operation that uses a partial reduction is a dot product, which reduces
+  /// two vectors in binary mul operation to another of 4 times fewer and 4
   /// times larger elements.
   InstructionCost
   getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 77be41b78bc7f..48424185c68de 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5177,11 +5177,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
 
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
-  if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
+  if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
+      OpAExtend == TTI::PR_None)
     return Invalid;
 
-  if (InputTypeA != InputTypeB)
+  // We only support multiply binary operations for now, and for muls we
+  // require the types being extended to be the same.
+  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
+  // only if the i8mm or sve/streaming features are available.
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
+                OpBExtend == TTI::PR_None ||
+                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+                 !ST->isSVEorStreamingSVEAvailable())))
     return Invalid;
+  assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
+         "Unexpected values for OpBExtend or InputTypeB");
 
   EVT InputEVT = EVT::getEVT(InputTypeA);
   EVT AccumEVT = EVT::getEVT(AccumType);
@@ -5228,16 +5238,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
   } else
     return Invalid;
 
-  // AArch64 supports lowering mixed extensions to a usdot but only if the
-  // i8mm or sve/streaming features are available.
-  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
-    return Invalid;
-
-  if (!BinOp || *BinOp != Instruction::Mul)
-    return Invalid;
-
   return Cost;
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0291a8bfd9674..654f3ecacf51b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   // something that isn't another partial reduction. This is because the
   // extends are intended to be lowered along with the reduction itself.
 
-  // Build up a set of partial reduction bin ops for efficient use checking.
-  SmallSet<User *, 4> PartialReductionBinOps;
+  // Build up a set of partial reduction ops for efficient use checking.
+  SmallSet<User *, 4> PartialReductionOps;
   for (const auto &[PartialRdx, _] : PartialReductionChains)
-    PartialReductionBinOps.insert(PartialRdx.BinOp);
+    PartialReductionOps.insert(PartialRdx.ExtendUser);
 
   auto ExtendIsOnlyUsedByPartialReductions =
-      [&PartialReductionBinOps](Instruction *Extend) {
+      [&PartialReductionOps](Instruction *Extend) {
         return all_of(Extend->users(), [&](const User *U) {
-          return PartialReductionBinOps.contains(U);
+          return PartialReductionOps.contains(U);
         });
       };
 
@@ -8782,7 +8782,7 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   for (auto Pair : PartialReductionChains) {
     PartialReductionChain Chain = Pair.first;
     if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+        (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
       ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
   }
 }
@@ -8790,7 +8790,6 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 bool VPRecipeBuilder::getScaledReductions(
     Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
     SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
-
   if (!CM.TheLoop->contains(RdxExitInstr))
     return false;
 
@@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
   if (PhiOp != PHI)
     return false;
 
-  auto *BinOp = dyn_cast<BinaryOperator>(Op);
-  if (!BinOp || !BinOp->hasOneUse())
-    return false;
-
   using namespace llvm::PatternMatch;
-  // Use the side-effect of match to replace BinOp only if the pattern is
-  // matched, we don't care at this point whether it actually matched.
-  match(BinOp, m_Neg(m_BinOp(BinOp)));
 
-  Value *A, *B;
-  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
-      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
-    return false;
+  // If the update is a binary operator, check both of its operands to see if
+  // they are extends. Otherwise, see if the update comes directly from an
+  // extend.
+  Instruction *Exts[2] = {nullptr};
+  BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
+  std::optional<unsigned> BinOpc;
+  Type *ExtOpTypes[2] = {nullptr};
+
+  auto CollectExtInfo = [&Exts,
+                         &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
+    unsigned I = 0;
+    for (Value *OpI : Ops) {
+      Value *ExtOp;
+      if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
+        return false;
+      Exts[I] = cast<Instruction>(OpI);
+      ExtOpTypes[I] = ExtOp->getType();
+      I++;
+    }
+    return true;
+  };
+
+  if (ExtendUser) {
+    if (!ExtendUser->hasOneUse())
+      return false;
+
+    // Use the side-effect of match to replace BinOp only if the pattern is
+    // matched, we don't care at this point whether it actually matched.
+    match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
 
-  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
-  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+    SmallVector<Value *> Ops(ExtendUser->operands());
+    if (!CollectExtInfo(Ops))
+      return false;
+
+    BinOpc = std::make_optional(ExtendUser->getOpcode());
+  } else if (match(Update, m_Add(m_Value(), m_Value()))) {
+    // We already know the operands for Update are Op and PhiOp.
+    SmallVector<Value *> Ops({Op});
+    if (!CollectExtInfo(Ops))
+      return false;
+
+    ExtendUser = Update;
+    BinOpc = std::nullopt;
+  } else
+    return false;
 
   TTI::PartialReductionExtendKind OpAExtend =
-      TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+      TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
   TTI::PartialReductionExtendKind OpBExtend =
-      TargetTransformInfo::getPartialReductionExtendKind(ExtB);
-
-  PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
+      Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
+              : TargetTransformInfo::PR_None;
+  PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
 
   unsigned TargetScaleFactor =
       PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
-          A->getType()->getPrimitiveSizeInBits());
+          ExtOpTypes[0]->getPrimitiveSizeInBits());
 
   if (LoopVectorizationPlanner::getDecisionAndClampRange(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
-                Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend,
-                std::make_optional(BinOp->getOpcode()));
+                Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
+                PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 334cfbad8bd7c..8d2d187231303 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -26,13 +26,14 @@ struct HistogramInfo;
 struct VFRange;
 
 /// A chain of instructions that form a partial reduction.
-/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
-/// accumulator).
+/// Designed to match either:
+///   reduction_bin_op (extend (A), accumulator), or
+///   reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
 struct PartialReductionChain {
   PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
-                        Instruction *ExtendB, Instruction *BinOp)
-      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
-  }
+                        Instruction *ExtendB, Instruction *ExtendUser)
+      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
+        ExtendUser(ExtendUser) {}
   /// The top-level binary operation that forms the reduction to a scalar
   /// after the loop body.
   Instruction *Reduction;
@@ -40,8 +41,8 @@ struct PartialReductionChain {
   Instruction *ExtendA;
   Instruction *ExtendB;
 
-  /// The binary operation using the extends that is then reduced.
-  Instruction *BinOp;
+  /// The user of the extend that is then reduced.
+  Instruction *ExtendUser;
 };
 
 /// Helper class to create VPRecipies from IR instructions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b16a8fc563f4c..ad27e9435669f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -281,31 +281,18 @@ bool VPRecipeBase::isPhi() const {
 InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
-  std::optional<unsigned> Opcode = std::nullopt;
-  VPValue *BinOp = getOperand(0);
+  // If the input operand is an extend then use the opcode for this recipe.
+  std::optional<unsigned> Opcode;
+  VPValue *Op = getOperand(0);
+  VPRecipeBase *OpR = Op->getDefiningRecipe();
 
   // If the partial reduction is predicated, a select will be operand 0 rather
-  // than the binary op
+  // than the extend user.
   using namespace llvm::VPlanPatternMatch;
-  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
-    BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
-  // If BinOp is a negation, use the side effect of match to assign the actual
-  // binary operation to BinOp
-  match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
-  VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
-
-  if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
-    Opcode = std::make_optional(WidenR->getOpcode());
-
-  VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
-  VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
-
-  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
-  auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
-                                                     : BinOpR->getOperand(0));
-  auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
-                                                     : BinOpR->getOperand(1));
+  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue()))) {
+    Op = OpR->getOperand(1);
+    OpR = Op->getDefiningRecipe();
+  }
 
   auto GetExtendKind = [](VPRecipeBase *R) {
     // The extend could come from outside the plan.
@@ -321,9 +308,38 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
     return TargetTransformInfo::PR_None;
   };
 
+  Type *InputTypeA, *InputTypeB;
+  TTI::PartialReductionExtendKind ExtAType, ExtBType;
+
+  // The input may come straight from a zext or sext.
+  if (isa<VPWidenCastRecipe>(OpR)) {
+    Opcode = std::nullopt;
+    InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
+    InputTypeB = nullptr;
+    ExtAType = GetExtendKind(OpR);
+    ExtBType = TargetTransformInfo::PR_None;
+  } else {
+    // If BinOp is a negation, use the side effect of match to assign the actual
+    // binary operation to BinOp
+    match(Op, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)));
+    OpR = Op->getDefiningRecipe();
+    Opcode = std::make_optional(cast<VPWidenRecipe>(OpR)->getOpcode());
+
+    VPRecipeBase *ExtAR = OpR->getOperand(0)->getDefiningRecipe();
+    VPRecipeBase *ExtBR = OpR->getOperand(1)->getDefiningRecipe();
+
+    InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
+                                                 : OpR->getOperand(0));
+    InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
+                                                 : OpR->getOperand(1));
+    ExtAType = GetExtendKind(ExtAR);
+    ExtBType = GetExtendKind(ExtBR);
+  }
+
+  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
   return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
-                                         PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+                                         PhiType, VF, ExtAType, ExtBType,
+                                         Opcode);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index a229ca8c6e6db..075742ff95b04 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -1030,6 +1030,472 @@ for.body:                                         ; preds = %for.body.preheader,
   br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
 }
 
+
+define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-NEON-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-NEON-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEON-NEXT:  entry:
+; CHECK-NEON-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEON-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEON-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEON-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEON:       vector.ph:
+; CHECK-NEON-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-NEON:       vector.body:
+; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP3:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP1]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEON-NEXT:    [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEON-NEXT:    [[TMP6:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP3]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
+; CHECK-NEON-NEXT:    [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP9]])
+; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-NEON:       middle.block:
+; CHECK-NEON-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK-NEON:       scalar.ph:
+;
+; CHECK-SVE-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-SVE-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-SVE-NEXT:  entry:
+; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-SVE:       vector.ph:
+; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-SVE:       vector.body:
+; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]...
[truncated]
 | 
| @llvm/pr-subscribers-llvm-analysis Author: David Sherwood (david-arm) ChangesConsider IR such as this: for.body: Conceptually we can vectorise this using partial reductions too, In order to do this I had to teach getScaledReductions that the Patch is 144.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133922.diff 9 Files Affected: 
 diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 4835c66a7a3bc..5f3c8ff3bdfb4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1299,9 +1299,21 @@ class TargetTransformInfo {
   /// \return The cost of a partial reduction, which is a reduction from a
   /// vector to another vector with fewer elements of larger size. They are
   /// represented by the llvm.experimental.partial.reduce.add intrinsic, which
-  /// takes an accumulator and a binary operation operand that itself is fed by
-  /// two extends. An example of an operation that uses a partial reduction is a
-  /// dot product, which reduces two vectors to another of 4 times fewer and 4
+  /// takes an accumulator of type \p AccumType and a second vector operand to
+  /// be accumulated, whose element count is specified by \p VF. The type of
+  /// reduction is specified by \p Opcode. The second operand passed to the
+  /// intrinsic could be the result of an extend, such as sext or zext. In
+  /// this case \p BinOp is nullopt, \p InputTypeA represents the type being
+  /// extended and \p OpAExtend the operation, i.e. sign- or zero-extend.
+  /// Also, \p InputTypeB should be nullptr and OpBExtend should be None.
+  /// Alternatively, the second operand could be the result of a binary
+  /// operation performed on two extends, i.e.
+  ///   mul(zext i8 %a -> i32, zext i8 %b -> i32).
+  /// In this case \p BinOp may specify the opcode of the binary operation,
+  /// \p InputTypeA and \p InputTypeB the types being extended, and
+  /// \p OpAExtend, \p OpBExtend the form of extensions. An example of an
+  /// operation that uses a partial reduction is a dot product, which reduces
+  /// two vectors in binary mul operation to another of 4 times fewer and 4
   /// times larger elements.
   InstructionCost
   getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 77be41b78bc7f..48424185c68de 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5177,11 +5177,21 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
 
   // Sub opcodes currently only occur in chained cases.
   // Independent partial reduction subtractions are still costed as an add
-  if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
+  if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
+      OpAExtend == TTI::PR_None)
     return Invalid;
 
-  if (InputTypeA != InputTypeB)
+  // We only support multiply binary operations for now, and for muls we
+  // require the types being extended to be the same.
+  // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
+  // only if the i8mm or sve/streaming features are available.
+  if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
+                OpBExtend == TTI::PR_None ||
+                (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
+                 !ST->isSVEorStreamingSVEAvailable())))
     return Invalid;
+  assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
+         "Unexpected values for OpBExtend or InputTypeB");
 
   EVT InputEVT = EVT::getEVT(InputTypeA);
   EVT AccumEVT = EVT::getEVT(AccumType);
@@ -5228,16 +5238,6 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
   } else
     return Invalid;
 
-  // AArch64 supports lowering mixed extensions to a usdot but only if the
-  // i8mm or sve/streaming features are available.
-  if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
-    return Invalid;
-
-  if (!BinOp || *BinOp != Instruction::Mul)
-    return Invalid;
-
   return Cost;
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0291a8bfd9674..654f3ecacf51b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8765,15 +8765,15 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   // something that isn't another partial reduction. This is because the
   // extends are intended to be lowered along with the reduction itself.
 
-  // Build up a set of partial reduction bin ops for efficient use checking.
-  SmallSet<User *, 4> PartialReductionBinOps;
+  // Build up a set of partial reduction ops for efficient use checking.
+  SmallSet<User *, 4> PartialReductionOps;
   for (const auto &[PartialRdx, _] : PartialReductionChains)
-    PartialReductionBinOps.insert(PartialRdx.BinOp);
+    PartialReductionOps.insert(PartialRdx.ExtendUser);
 
   auto ExtendIsOnlyUsedByPartialReductions =
-      [&PartialReductionBinOps](Instruction *Extend) {
+      [&PartialReductionOps](Instruction *Extend) {
         return all_of(Extend->users(), [&](const User *U) {
-          return PartialReductionBinOps.contains(U);
+          return PartialReductionOps.contains(U);
         });
       };
 
@@ -8782,7 +8782,7 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   for (auto Pair : PartialReductionChains) {
     PartialReductionChain Chain = Pair.first;
     if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-        ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
+        (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
       ScaledReductionMap.insert(std::make_pair(Chain.Reduction, Pair.second));
   }
 }
@@ -8790,7 +8790,6 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 bool VPRecipeBuilder::getScaledReductions(
     Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
     SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
-
   if (!CM.TheLoop->contains(RdxExitInstr))
     return false;
 
@@ -8819,40 +8818,70 @@ bool VPRecipeBuilder::getScaledReductions(
   if (PhiOp != PHI)
     return false;
 
-  auto *BinOp = dyn_cast<BinaryOperator>(Op);
-  if (!BinOp || !BinOp->hasOneUse())
-    return false;
-
   using namespace llvm::PatternMatch;
-  // Use the side-effect of match to replace BinOp only if the pattern is
-  // matched, we don't care at this point whether it actually matched.
-  match(BinOp, m_Neg(m_BinOp(BinOp)));
 
-  Value *A, *B;
-  if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
-      !match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
-    return false;
+  // If the update is a binary operator, check both of its operands to see if
+  // they are extends. Otherwise, see if the update comes directly from an
+  // extend.
+  Instruction *Exts[2] = {nullptr};
+  BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
+  std::optional<unsigned> BinOpc;
+  Type *ExtOpTypes[2] = {nullptr};
+
+  auto CollectExtInfo = [&Exts,
+                         &ExtOpTypes](SmallVectorImpl<Value *> &Ops) -> bool {
+    unsigned I = 0;
+    for (Value *OpI : Ops) {
+      Value *ExtOp;
+      if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))))
+        return false;
+      Exts[I] = cast<Instruction>(OpI);
+      ExtOpTypes[I] = ExtOp->getType();
+      I++;
+    }
+    return true;
+  };
+
+  if (ExtendUser) {
+    if (!ExtendUser->hasOneUse())
+      return false;
+
+    // Use the side-effect of match to replace BinOp only if the pattern is
+    // matched, we don't care at this point whether it actually matched.
+    match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
 
-  Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
-  Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));
+    SmallVector<Value *> Ops(ExtendUser->operands());
+    if (!CollectExtInfo(Ops))
+      return false;
+
+    BinOpc = std::make_optional(ExtendUser->getOpcode());
+  } else if (match(Update, m_Add(m_Value(), m_Value()))) {
+    // We already know the operands for Update are Op and PhiOp.
+    SmallVector<Value *> Ops({Op});
+    if (!CollectExtInfo(Ops))
+      return false;
+
+    ExtendUser = Update;
+    BinOpc = std::nullopt;
+  } else
+    return false;
 
   TTI::PartialReductionExtendKind OpAExtend =
-      TargetTransformInfo::getPartialReductionExtendKind(ExtA);
+      TargetTransformInfo::getPartialReductionExtendKind(Exts[0]);
   TTI::PartialReductionExtendKind OpBExtend =
-      TargetTransformInfo::getPartialReductionExtendKind(ExtB);
-
-  PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
+      Exts[1] ? TargetTransformInfo::getPartialReductionExtendKind(Exts[1])
+              : TargetTransformInfo::PR_None;
+  PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
 
   unsigned TargetScaleFactor =
       PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
-          A->getType()->getPrimitiveSizeInBits());
+          ExtOpTypes[0]->getPrimitiveSizeInBits());
 
   if (LoopVectorizationPlanner::getDecisionAndClampRange(
           [&](ElementCount VF) {
             InstructionCost Cost = TTI->getPartialReductionCost(
-                Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
-                VF, OpAExtend, OpBExtend,
-                std::make_optional(BinOp->getOpcode()));
+                Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
+                PHI->getType(), VF, OpAExtend, OpBExtend, BinOpc);
             return Cost.isValid();
           },
           Range)) {
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 334cfbad8bd7c..8d2d187231303 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -26,13 +26,14 @@ struct HistogramInfo;
 struct VFRange;
 
 /// A chain of instructions that form a partial reduction.
-/// Designed to match: reduction_bin_op (bin_op (extend (A), (extend (B))),
-/// accumulator).
+/// Designed to match either:
+///   reduction_bin_op (extend (A), accumulator), or
+///   reduction_bin_op (bin_op (extend (A), (extend (B))), accumulator).
 struct PartialReductionChain {
   PartialReductionChain(Instruction *Reduction, Instruction *ExtendA,
-                        Instruction *ExtendB, Instruction *BinOp)
-      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB), BinOp(BinOp) {
-  }
+                        Instruction *ExtendB, Instruction *ExtendUser)
+      : Reduction(Reduction), ExtendA(ExtendA), ExtendB(ExtendB),
+        ExtendUser(ExtendUser) {}
   /// The top-level binary operation that forms the reduction to a scalar
   /// after the loop body.
   Instruction *Reduction;
@@ -40,8 +41,8 @@ struct PartialReductionChain {
   Instruction *ExtendA;
   Instruction *ExtendB;
 
-  /// The binary operation using the extends that is then reduced.
-  Instruction *BinOp;
+  /// The user of the extend that is then reduced.
+  Instruction *ExtendUser;
 };
 
 /// Helper class to create VPRecipies from IR instructions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index b16a8fc563f4c..ad27e9435669f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -281,31 +281,18 @@ bool VPRecipeBase::isPhi() const {
 InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
-  std::optional<unsigned> Opcode = std::nullopt;
-  VPValue *BinOp = getOperand(0);
+  // If the input operand is an extend then use the opcode for this recipe.
+  std::optional<unsigned> Opcode;
+  VPValue *Op = getOperand(0);
+  VPRecipeBase *OpR = Op->getDefiningRecipe();
 
   // If the partial reduction is predicated, a select will be operand 0 rather
-  // than the binary op
+  // than the extend user.
   using namespace llvm::VPlanPatternMatch;
-  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
-    BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
-  // If BinOp is a negation, use the side effect of match to assign the actual
-  // binary operation to BinOp
-  match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
-  VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
-
-  if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
-    Opcode = std::make_optional(WidenR->getOpcode());
-
-  VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
-  VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
-
-  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
-  auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
-                                                     : BinOpR->getOperand(0));
-  auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
-                                                     : BinOpR->getOperand(1));
+  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue()))) {
+    Op = OpR->getOperand(1);
+    OpR = Op->getDefiningRecipe();
+  }
 
   auto GetExtendKind = [](VPRecipeBase *R) {
     // The extend could come from outside the plan.
@@ -321,9 +308,38 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
     return TargetTransformInfo::PR_None;
   };
 
+  Type *InputTypeA, *InputTypeB;
+  TTI::PartialReductionExtendKind ExtAType, ExtBType;
+
+  // The input may come straight from a zext or sext.
+  if (isa<VPWidenCastRecipe>(OpR)) {
+    Opcode = std::nullopt;
+    InputTypeA = Ctx.Types.inferScalarType(OpR->getOperand(0));
+    InputTypeB = nullptr;
+    ExtAType = GetExtendKind(OpR);
+    ExtBType = TargetTransformInfo::PR_None;
+  } else {
+    // If BinOp is a negation, use the side effect of match to assign the actual
+    // binary operation to BinOp
+    match(Op, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Op)));
+    OpR = Op->getDefiningRecipe();
+    Opcode = std::make_optional(cast<VPWidenRecipe>(OpR)->getOpcode());
+
+    VPRecipeBase *ExtAR = OpR->getOperand(0)->getDefiningRecipe();
+    VPRecipeBase *ExtBR = OpR->getOperand(1)->getDefiningRecipe();
+
+    InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
+                                                 : OpR->getOperand(0));
+    InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
+                                                 : OpR->getOperand(1));
+    ExtAType = GetExtendKind(ExtAR);
+    ExtBType = GetExtendKind(ExtBR);
+  }
+
+  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
   return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
-                                         PhiType, VF, GetExtendKind(ExtAR),
-                                         GetExtendKind(ExtBR), Opcode);
+                                         PhiType, VF, ExtAType, ExtBType,
+                                         Opcode);
 }
 
 void VPPartialReductionRecipe::execute(VPTransformState &State) {
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
index a229ca8c6e6db..075742ff95b04 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -1030,6 +1030,472 @@ for.body:                                         ; preds = %for.body.preheader,
   br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
 }
 
+
+define i32 @chained_partial_reduce_madd_extadd(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-NEON-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-NEON-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEON-NEXT:  entry:
+; CHECK-NEON-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEON-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEON-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEON-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEON:       vector.ph:
+; CHECK-NEON-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-NEON:       vector.body:
+; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE3:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP3:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[INDEX]]
+; CHECK-NEON-NEXT:    [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP1]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEON-NEXT:    [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEON-NEXT:    [[TMP6:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP3]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
+; CHECK-NEON-NEXT:    [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE:%.*]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP10]])
+; CHECK-NEON-NEXT:    [[PARTIAL_REDUCE3]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[PARTIAL_REDUCE]], <16 x i32> [[TMP9]])
+; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[TMP12]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP16:![0-9]+]]
+; CHECK-NEON:       middle.block:
+; CHECK-NEON-NEXT:    [[TMP13:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[PARTIAL_REDUCE3]])
+; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+; CHECK-NEON:       scalar.ph:
+;
+; CHECK-SVE-LABEL: define i32 @chained_partial_reduce_madd_extadd(
+; CHECK-SVE-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-SVE-NEXT:  entry:
+; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-SVE:       vector.ph:
+; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-SVE:       vector.body:
+; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]...
[truncated]
 | 
| Gentle ping. :) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious if/how this will interact with #136173?
Would it potentially help to simplify computeCost & co?
| 
 I think we'll want to turn a partial reduction without a bin op into a  | 
| 
 @fhahn @SamTebbs33 Do you think that #136173 is likely to land soon? Ideally I'd like to progress this patch within the next few weeks. It looks like #136173 also depends upon #113903, which looks like it might still take some time. | 
| If both of those PRs are close to landing I could try to check them out and see how my PR interacts with them, but I'd rather do that once I know they're clos(ish) to the finish line. | 
| 
 I think that 136173 is very very close to being approved and there shouldn't be any big changes to it on the way. | 
| Hi, @david-arm asked me to look after this pull request for him. I've rebased it and have added some new logic to  Regarding the previous conversation about waiting for @SamTebbs33's work, we've discussed the situation internally and would like to prioritize getting this patch along so that it's ready before the branch point for 21. @fhahn Could you please re-review this patch? Thank you. | 
| Hi @MDevereau the code that this PR touches is currently being refactored quite significantly. I think it's worth holding off a little longer until #144908 and #144281 make it in, and then rebase this patch on top of those changes. Those patches should also make it easier to model this improvement. | 
| We at Arm have discussed this offline and because this PR improves a key workload and the LLVM 21 branch point is fast approaching we prefer to detach it from the refactoring work Sander mentions above. There is still potential that all PRs could land in time but we just don't want to risk this PR not making the release. @fhahn: Any objections to us prioritising the performance improvement? | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels odd to return the cost of another operation (since it may be different?), but since we're hoping to use the bundle recipe for costing we can fix things there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| 
 I agree with @sdesmalen-arm that it probably would be good to first land the refactorings if possible. But if it looks like the refactorings won't land in time by the middle of next week, lets pull this PR ahead of the other patches? That should hopefully leave enough time before the branch in either case. | 
Consider IR such as this: for.body: %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] %gep.a = getelementptr i8, ptr %a, i64 %iv %load.a = load i8, ptr %gep.a, align 1 %ext.a = zext i8 %load.a to i32 %add = add i32 %ext.a, %accum %iv.next = add i64 %iv, 1 %exitcond.not = icmp eq i64 %iv.next, 1025 br i1 %exitcond.not, label %for.exit, label %for.body Conceptually we can vectorise this using partial reductions too, although the current loop vectoriser implementation requires the accumulation of a multiply. For AArch64 this is easily done with a udot or sdot with an identity operand, i.e. a vector of (i16 1). In order to do this I had to teach getScaledReductions that the accumulated value may come from a unary op, hence there is only one extension to consider. Similarly, I updated the vplan and AArch64 TTI cost model to understand the possible unary op. Co-authored-by: Matt Devereau <[email protected]>
| 
 Works for me. Let's say Wednesday 12pm uk time. That should give us just enough time to recover if post-merge testing uncovers something unexpected. | 
| 
 I don't think there will be many difficult-to-resolve conflicts when rebasing #144908 and #144281 on top of this change since most of this PR affects the early partial reduction detection code. I'm happy to fix the merge conflicts in  | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
|  | ||
| /// The binary operation using the extends that is then reduced. | ||
| Instruction *BinOp; | ||
| /// The user of the extend that is then reduced. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| /// The user of the extend that is then reduced. | |
| /// The user of the extends that is then reduced. | 
Still for multiple extends, right?
| // Pick out opcode, type/ext information and use sub side effects from a widen | ||
| // recipe. | ||
| auto HandleWiden = [&](VPWidenRecipe *Widen) { | ||
| if (match(Widen, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On top of #146073, would we still need to detect the various patterns here or would/could this be replaced by different ExpressionTypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When applying this patch on top of #146073, most tests pass even with this entire VPlan being deleted. The main exception I've seen is with the test zext_add_reduc_i8_i32_has_neon_dotprod which ends up with an invalid cost as it tries to go through this which isn't passing any information about extends through to getPartialReductionCost
| @@ -0,0 +1,826 @@ | |||
| ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals none --filter-out-after "^scalar.ph:" --version 4 | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference to llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-neon.ll? Just no dotprod feature? Would be good to have clearer name, in this test neon is also available?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test in partial-reduce.ll has the SVE attribute which means it will generate scalable code for the CHECK-MAXBW tests. I can change this one to partial-reduce-sve.ll.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or rather, it should be possible to delete partial-reduce-neon.ll and instead just use different attributes for each test instead of having a different file
| I am suspecting that this is causing: | 
Consider IR such as this:
for.body:
%iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
%accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
%gep.a = getelementptr i8, ptr %a, i64 %iv
%load.a = load i8, ptr %gep.a, align 1
%ext.a = zext i8 %load.a to i32
%add = add i32 %ext.a, %accum
%iv.next = add i64 %iv, 1
%exitcond.not = icmp eq i64 %iv.next, 1025
br i1 %exitcond.not, label %for.exit, label %for.body
Conceptually we can vectorise this using partial reductions too,
although the current loop vectoriser implementation requires the
accumulation of a multiply. For AArch64 this is easily done with
a udot or sdot with an identity operand, i.e. a vector of (i16 1).
In order to do this I had to teach getScaledReductions that the
accumulated value may come from a unary op, hence there is only
one extension to consider. Similarly, I updated the vplan and
AArch64 TTI cost model to understand the possible unary op.