Conversation

davemgreen (Collaborator) commented Feb 4, 2025

We already have cost model code for detecting extending mull multiplies of the form mul(ext, ext). Since it was added, the codegen for mull has been improved; this patch attempts to catch the cost model up.

The main idea is to incorporate extends of larger sizes. A vector v8i32 mul(zext(v8i8), zext(v8i8)) will be code-generated as zext(v8i16 mul(zext(v8i8), zext(v8i8))), i.e. umull+ushll+ushll2.
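As a concrete sketch (hand-written for illustration, not taken from the patch's tests; the function name is invented), the pattern and its expected lowering look like:

define <8 x i32> @mull_zext(<8 x i8> %a, <8 x i8> %b) {
  %za = zext <8 x i8> %a to <8 x i32>
  %zb = zext <8 x i8> %b to <8 x i32>
  %m = mul <8 x i32> %za, %zb
  ; Expected codegen, per the description above:
  ;   umull  v0.8h, v0.8b, v1.8b    ; the v8i16 mul of the two v8i8 inputs
  ;   ushll2 v1.4s, v0.8h, #0       ; high half of the v8i16 -> v4i32
  ;   ushll  v0.4s, v0.4h, #0       ; low half of the v8i16  -> v4i32
  ret <8 x i32> %m
}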

So the total cost should be 3ish if each instruction costs 1. Where exactly we attribute the costs is debatable; this patch opts to set the cost of the extend to 0 (or the cost of any part of the extend not folded into the mull), and the mul gets the cost of the mull plus the extra extends.
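The updated arith-widening.ll expectations in the diff show this attribution concretely: an add <8 x i32> of two sext <8 x i8> operands now reports cost 0 for each sext and cost 3 for the add (one addl plus the two extends of the wider result). The expectations come from the standard cost-model printer; a typical invocation (flags as used by existing cost model tests, shown here as an assumption) is:

opt -mtriple=aarch64 -passes="print<cost-model>" -disable-output arith-widening.ll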

isWideningInstruction is split into two functions for the two operand patterns it supports: isSingleExtWideningInstruction now handles addw-style instructions that extend only the second operand, while isBinExtWideningInstruction is for addl-style instructions that extend both operands.
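A minimal IR sketch of the two shapes (value and function names invented for the example; the expected instructions are in the comments):

define <8 x i16> @widening_shapes(<8 x i16> %acc, <8 x i8> %a, <8 x i8> %b) {
  ; Single-ext (addw-style): only the second operand is an extend.
  %wb = zext <8 x i8> %b to <8 x i16>
  %w = add <8 x i16> %acc, %wb       ; -> uaddw v0.8h, v0.8h, v2.8b
  ; Bin-ext (addl-style): both operands are extends.
  %ea = zext <8 x i8> %a to <8 x i16>
  %eb = zext <8 x i8> %b to <8 x i16>
  %l = add <8 x i16> %ea, %eb        ; -> uaddl v1.8h, v1.8b, v2.8b
  %r = add <8 x i16> %w, %l
  ret <8 x i16> %r
}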

llvmbot added the backend:AArch64, llvm:analysis, and llvm:transforms labels on Feb 4, 2025
llvmbot (Member) commented Feb 4, 2025

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: David Green (davemgreen)

Changes

The changes in the partial reduction tests show that they need a better cost model, one that treats the mul + extends as free when they feed a dot product. We should fix that first.


Patch is 266.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125651.diff

7 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+118-40)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+11-3)
  • (modified) llvm/test/Analysis/CostModel/AArch64/arith-widening.ll (+126-126)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/fully-unrolled-cost.ll (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll (+652-168)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+10-3)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+348-623)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index aae2fdaf5bec37..31cc6ed8f1a665 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -2585,9 +2585,9 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
   llvm_unreachable("Unsupported register kind");
 }
 
-bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<const Value *> Args,
-                                           Type *SrcOverrideTy) {
+bool AArch64TTIImpl::isSingleExtWideningInstruction(
+    unsigned Opcode, Type *DstTy, ArrayRef<const Value *> Args,
+    Type *SrcOverrideTy) {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -2605,48 +2605,29 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
-  // Determine if the operation has a widening variant. We consider both the
-  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
-  // instructions.
-  //
-  // TODO: Add additional widening operations (e.g., shl, etc.) once we
-  //       verify that their extending operands are eliminated during code
-  //       generation.
   Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
-  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
-  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
+  case Instruction::Add:   // UADDW(2), SADDW(2).
+  case Instruction::Sub: { // USUBW(2), SSUBW(2).
     // The second operand needs to be an extend
     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
-    } else
+      break;
+    }
+
+    if (Opcode == Instruction::Sub)
       return false;
-    break;
-  case Instruction::Mul: { // SMULL(2), UMULL(2)
-    // Both operands need to be extends of the same type.
-    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
-        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+
+    // UADDW(2), SADDW(2) can be commuted.
+    if (isa<SExtInst>(Args[0]) || isa<ZExtInst>(Args[0])) {
       if (!SrcTy)
         SrcTy =
             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
-    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
-      // If one of the operands is a Zext and the other has enough zero bits to
-      // be treated as unsigned, we can still general a umull, meaning the zext
-      // is free.
-      KnownBits Known =
-          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
-      if (Args[0]->getType()->getScalarSizeInBits() -
-              Known.Zero.countLeadingOnes() >
-          DstTy->getScalarSizeInBits() / 2)
-        return false;
-      if (!SrcTy)
-        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
-                                           DstTy->getScalarSizeInBits() / 2));
-    } else
-      return false;
-    break;
+      break;
+    }
+    return false;
   }
   default:
     return false;
@@ -2677,6 +2658,73 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
+Type *AArch64TTIImpl::isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                                  ArrayRef<const Value *> Args,
+                                                  Type *SrcOverrideTy) {
+  if (Opcode != Instruction::Add && Opcode != Instruction::Sub &&
+      Opcode != Instruction::Mul)
+    return nullptr;
+
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
+  // perform an extend with the add/sub/mul. There are SMULLB style
+  // instructions, but they operate on top/bottom, requiring some sort of lane
+  // interleaving to be used with zext/sext.
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
+    return nullptr;
+
+  auto getScalarSizeWithOverride = [&](const Value *V) {
+    if (SrcOverrideTy)
+      return SrcOverrideTy->getScalarSizeInBits();
+    return cast<Instruction>(V)
+        ->getOperand(0)
+        ->getType()
+        ->getScalarSizeInBits();
+  };
+
+  unsigned MaxEltSize = 0;
+  if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+      (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    MaxEltSize = std::max(EltSize0, EltSize1);
+  } else if (isa<SExtInst, ZExtInst>(Args[0]) &&
+             isa<SExtInst, ZExtInst>(Args[1])) {
+    unsigned EltSize0 = getScalarSizeWithOverride(Args[0]);
+    unsigned EltSize1 = getScalarSizeWithOverride(Args[1]);
+    // mul(sext, zext) will become smull(sext, zext) if the extends are large
+    // enough.
+    if (EltSize0 >= DstEltSize / 2 || EltSize1 >= DstEltSize / 2)
+      return nullptr;
+    MaxEltSize = DstEltSize / 2;
+  } else if (Opcode == Instruction::Mul &&
+             (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1]))) {
+    // If one of the operands is a Zext and the other has enough zero bits
+    // to be treated as unsigned, we can still generate a umull, meaning the
+    // zext is free.
+    KnownBits Known =
+        computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+    if (Args[0]->getType()->getScalarSizeInBits() -
+            Known.Zero.countLeadingOnes() >
+        DstTy->getScalarSizeInBits() / 2)
+      return nullptr;
+
+    MaxEltSize =
+        getScalarSizeWithOverride(isa<ZExtInst>(Args[0]) ? Args[0] : Args[1]);
+  } else
+    return nullptr;
+
+  if (MaxEltSize * 2 > DstEltSize)
+    return nullptr;
+
+  Type *ExtTy = DstTy->getWithNewBitWidth(MaxEltSize * 2);
+  if (ExtTy->getPrimitiveSizeInBits() <= 64)
+    return nullptr;
+  return ExtTy;
+}
+
 // s/urhadd instructions implement the following pattern, making the
 // extends free:
 //   %x = add ((zext i8 -> i16), 1)
@@ -2737,7 +2785,24 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+    if (Type *ExtTy = isBinExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
+      // The cost from Src->Src*2 needs to be added if required, the cost from
+      // Src*2->ExtTy is free.
+      if (ExtTy->getScalarSizeInBits() > Src->getScalarSizeInBits() * 2) {
+        Type *DoubleSrcTy =
+            Src->getWithNewBitWidth(Src->getScalarSizeInBits() * 2);
+        return getCastInstrCost(Opcode, DoubleSrcTy, Src,
+                                TTI::CastContextHint::None, CostKind);
+      }
+
+      return 0;
+    }
+
+    if (isSingleExtWideningInstruction(
+            SingleUser->getOpcode(), Dst, Operands,
+            Src != I->getOperand(0)->getType() ? Src : nullptr)) {
       // For adds only count the second operand as free if both operands are
       // extends but not the same operation. (i.e both operands are not free in
       // add(sext, zext)).
@@ -2746,8 +2811,11 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
             (isa<CastInst>(SingleUser->getOperand(1)) &&
              cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
           return 0;
-      } else // Others are free so long as isWideningInstruction returned true.
+      } else {
+        // Others are free so long as isSingleExtWideningInstruction
+        // returned true.
         return 0;
+      }
     }
 
     // The cast will be free for the s/urhadd instructions
@@ -3496,6 +3564,18 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
 
+  // If the operation is a widening instruction (smull or umull) and both
+  // operands are extends the cost can be cheaper by considering that the
+  // operation will operate on the narrowest type size possible (double the
+  // largest input size) and a further extend.
+  if (Type *ExtTy = isBinExtWideningInstruction(Opcode, Ty, Args)) {
+    if (ExtTy != Ty)
+      return getArithmeticInstrCost(Opcode, ExtTy, CostKind) +
+             getCastInstrCost(Instruction::ZExt, Ty, ExtTy,
+                              TTI::CastContextHint::None, CostKind);
+    return LT.first;
+  }
+
   switch (ISD) {
   default:
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
@@ -3613,10 +3693,8 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // - two 2-cost i64 inserts, and
     // - two 1-cost muls.
     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
-    // LT.first = 2 the cost is 28. If both operands are extensions it will not
-    // need to scalarize so the cost can be cheaper (smull or umull).
-    // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
+    // LT.first = 2 the cost is 28.
+    if (LT.second != MVT::v2i64)
       return LT.first;
     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index b65e3c7a1ab20e..da9e639b802945 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -57,9 +57,17 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     VECTOR_LDST_FOUR_ELEMENTS
   };
 
-  bool isWideningInstruction(Type *DstTy, unsigned Opcode,
-                             ArrayRef<const Value *> Args,
-                             Type *SrcOverrideTy = nullptr);
+  /// Given an add/sub/mul operation, detect a widening addl/subl/mull pattern
+  /// where both operands can be treated like extends. Returns the minimal type
+  /// needed to compute the operation.
+  Type *isBinExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                    ArrayRef<const Value *> Args,
+                                    Type *SrcOverrideTy = nullptr);
+  /// Given an add/sub operation with a single extend operand, detect a
+  /// widening addw/subw pattern.
+  bool isSingleExtWideningInstruction(unsigned Opcode, Type *DstTy,
+                                      ArrayRef<const Value *> Args,
+                                      Type *SrcOverrideTy = nullptr);
 
   // A helper function called by 'getVectorInstrCost'.
   //
diff --git a/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll b/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
index 303bcfa289577c..0117299c27c2ee 100644
--- a/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
+++ b/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
@@ -325,14 +325,14 @@ define void @extaddv4(<4 x i8> %i8, <4 x i16> %i16, <4 x i32> %i32, <4 x i64> %i
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %azl_16_32 = add <4 x i32> %zl1_16_32, %zl2_16_32
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %sw_16_64 = sext <4 x i16> %i16 to <4 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %asw_16_64 = add <4 x i64> %i64, %sw_16_64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %sl1_16_64 = sext <4 x i16> %i16 to <4 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %sl2_16_64 = sext <4 x i16> %i16 to <4 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %asl_16_64 = add <4 x i64> %sl1_16_64, %sl2_16_64
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl1_16_64 = sext <4 x i16> %i16 to <4 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl2_16_64 = sext <4 x i16> %i16 to <4 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %asl_16_64 = add <4 x i64> %sl1_16_64, %sl2_16_64
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %zw_16_64 = zext <4 x i16> %i16 to <4 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %azw_16_64 = add <4 x i64> %i64, %zw_16_64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %zl1_16_64 = zext <4 x i16> %i16 to <4 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %zl2_16_64 = zext <4 x i16> %i16 to <4 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %azl_16_64 = add <4 x i64> %zl1_16_64, %zl2_16_64
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_16_64 = zext <4 x i16> %i16 to <4 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl2_16_64 = zext <4 x i16> %i16 to <4 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %azl_16_64 = add <4 x i64> %zl1_16_64, %zl2_16_64
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sw_32_64 = sext <4 x i32> %i32 to <4 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %asw_32_64 = add <4 x i64> %i64, %sw_32_64
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl1_32_64 = sext <4 x i32> %i32 to <4 x i64>
@@ -434,24 +434,24 @@ define void @extaddv8(<8 x i8> %i8, <8 x i16> %i16, <8 x i32> %i32, <8 x i64> %i
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %azl_8_16 = add <8 x i16> %zl1_8_16, %zl2_8_16
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %sw_8_32 = sext <8 x i8> %i8 to <8 x i32>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %asw_8_32 = add <8 x i32> %i32, %sw_8_32
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %sl1_8_32 = sext <8 x i8> %i8 to <8 x i32>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %sl2_8_32 = sext <8 x i8> %i8 to <8 x i32>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %asl_8_32 = add <8 x i32> %sl1_8_32, %sl2_8_32
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl1_8_32 = sext <8 x i8> %i8 to <8 x i32>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl2_8_32 = sext <8 x i8> %i8 to <8 x i32>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %asl_8_32 = add <8 x i32> %sl1_8_32, %sl2_8_32
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %zw_8_32 = zext <8 x i8> %i8 to <8 x i32>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %azw_8_32 = add <8 x i32> %i32, %zw_8_32
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %zl1_8_32 = zext <8 x i8> %i8 to <8 x i32>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %zl2_8_32 = zext <8 x i8> %i8 to <8 x i32>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %azl_8_32 = add <8 x i32> %zl1_8_32, %zl2_8_32
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_32 = zext <8 x i8> %i8 to <8 x i32>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl2_8_32 = zext <8 x i8> %i8 to <8 x i32>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 3 for instruction: %azl_8_32 = add <8 x i32> %zl1_8_32, %zl2_8_32
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %sw_8_64 = sext <8 x i8> %i8 to <8 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %asw_8_64 = add <8 x i64> %i64, %sw_8_64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %sl1_8_64 = sext <8 x i8> %i8 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %sl2_8_64 = sext <8 x i8> %i8 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %asl_8_64 = add <8 x i64> %sl1_8_64, %sl2_8_64
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl1_8_64 = sext <8 x i8> %i8 to <8 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl2_8_64 = sext <8 x i8> %i8 to <8 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %asl_8_64 = add <8 x i64> %sl1_8_64, %sl2_8_64
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %zw_8_64 = zext <8 x i8> %i8 to <8 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %azw_8_64 = add <8 x i64> %i64, %zw_8_64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %zl1_8_64 = zext <8 x i8> %i8 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %zl2_8_64 = zext <8 x i8> %i8 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %azl_8_64 = add <8 x i64> %zl1_8_64, %zl2_8_64
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_64 = zext <8 x i8> %i8 to <8 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl2_8_64 = zext <8 x i8> %i8 to <8 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 7 for instruction: %azl_8_64 = add <8 x i64> %zl1_8_64, %zl2_8_64
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sw_16_32 = sext <8 x i16> %i16 to <8 x i32>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %asw_16_32 = add <8 x i32> %i32, %sw_16_32
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl1_16_32 = sext <8 x i16> %i16 to <8 x i32>
@@ -464,14 +464,14 @@ define void @extaddv8(<8 x i8> %i8, <8 x i16> %i16, <8 x i32> %i32, <8 x i64> %i
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %azl_16_32 = add <8 x i32> %zl1_16_32, %zl2_16_32
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %sw_16_64 = sext <8 x i16> %i16 to <8 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %asw_16_64 = add <8 x i64> %i64, %sw_16_64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %sl1_16_64 = sext <8 x i16> %i16 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %sl2_16_64 = sext <8 x i16> %i16 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %asl_16_64 = add <8 x i64> %sl1_16_64, %sl2_16_64
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl1_16_64 = sext <8 x i16> %i16 to <8 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %sl2_16_64 = sext <8 x i16> %i16 to <8 x i64>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %asl_16_64 = add <8 x i64> %sl1_16_64, %sl2_16_64
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %zw_16_64 = zext <8 x i16> %i16 to <8 x i64>
 ; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %azw_16_64 = add <8 x i64> %i64, %zw_16_64
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %zl1_16_64 = zext <8 x i16> %i16 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 6 for instruction: %zl2_16_64 = zext <8 x i16> %i16 to <8 x i64>
-; CHECK-NEXT:  Cost Model: Found an estimated cost of 4 for instruction: %azl_16_64 = add <8 x i64> %zl1_16_64, %zl2_16_64
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_16_64 = zext <8 x i16> %i16 to <8 x i64>
+; CHECK-NEXT:  Cost Model...
[truncated]

; CHECK-MAXBW-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 8 x i8>, ptr [[TMP15]], align 1
; CHECK-MAXBW-NEXT: [[TMP20:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD4]] to <vscale x 8 x i32>
; CHECK-MAXBW-NEXT: [[TMP22:%.*]] = mul <vscale x 8 x i32> [[TMP20]], [[TMP13]]
; CHECK-MAXBW-NEXT: [[PARTIAL_REDUCE5]] = call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv8i32(<vscale x 2 x i32> [[VEC_PHI1]], <vscale x 8 x i32> [[TMP22]])
SamTebbs33 (Collaborator) commented Jul 21, 2025

I have a PR that bundles partial reductions inside VPExpressionRecipe, so that will end up hiding the cost of the extends and mul, hopefully reverting this change. That will only be the case if the partial reduction actually gets created, of course. Do your changes prevent the creation of partial reductions or does the LV just end up choosing a VF that doesn't have them?

davemgreen (Collaborator, Author) replied

Yep, this has been waiting a while as it shouldn't be changing this. I think it's the second: in this case it ends up picking a VF that doesn't have dot, as the patch just lowers the cost of the fixed-width versions.

SamTebbs33 (Collaborator) left a comment

Looks like a great improvement to me!
