
Conversation

@davemgreen
Collaborator

This splits the existing post-legalize lowering of vector umull/smull into two parts - one to perform the optimization of mul(ext,ext) -> mull and one to perform the v2i64 mul scalarization. The mull part is moved to post legalizer combine and has been taught a few extra tricks from SDAG, using known bits to convert mul(sext, zext) or mul(zext, zero-upper-bits) into umull. This can be important to prevent v2i64 scalarization of muls.
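
As an illustration (not part of the patch itself), a minimal IR sketch of the mul(ext,ext) pattern the new combine targets, with types chosen to match the existing tests in aarch64-smull.ll; the function name is hypothetical:

define <8 x i16> @umull_example(<8 x i8> %a, <8 x i8> %b) {
  ; Both operands are zero-extended to double the element width, so the
  ; multiply can be selected as a single umull instead of ushll + mul.
  %za = zext <8 x i8> %a to <8 x i16>
  %zb = zext <8 x i8> %b to <8 x i16>
  %m = mul <8 x i16> %za, %zb
  ret <8 x i16> %m
}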

@llvmbot
Member

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

Changes

This splits the existing post-legalize lowering of vector umull/smull into two parts - one to perform the optimization of mul(ext,ext) -> mull and one to perform the v2i64 mul scalarization. The mull part is moved to post legalizer combine and has been taught a few extra tricks from SDAG, using known bits to convert mul(sext, zext) or mul(zext, zero-upper-bits) into umull. This can be important to prevent v2i64 scalarization of muls.


Patch is 30.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112405.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64Combine.td (+14-6)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp (+103)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp (+10-53)
  • (modified) llvm/test/CodeGen/AArch64/aarch64-smull.ll (+55-137)
  • (modified) llvm/test/CodeGen/AArch64/neon-extmul.ll (+19-99)
diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
index ead6455ddd5278..bc6cc97e027645 100644
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -208,11 +208,19 @@ def mul_const : GICombineRule<
   (apply [{ applyAArch64MulConstCombine(*${root}, MRI, B, ${matchinfo}); }])
 >;
 
-def lower_mull : GICombineRule<
-  (defs root:$root),
+def mull_matchdata : GIDefMatchData<"std::tuple<bool, Register, Register>">;
+def extmultomull : GICombineRule<
+  (defs root:$root, mull_matchdata:$matchinfo),
+  (match (wip_match_opcode G_MUL):$root,
+          [{ return matchExtMulToMULL(*${root}, MRI, KB, ${matchinfo}); }]),
+  (apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer, ${matchinfo}); }])
+>;
+
+def lower_mulv2s64 : GICombineRule<
+  (defs root:$root, mull_matchdata:$matchinfo),
   (match (wip_match_opcode G_MUL):$root,
-          [{ return matchExtMulToMULL(*${root}, MRI); }]),
-  (apply [{ applyExtMulToMULL(*${root}, MRI, B, Observer); }])
+          [{ return matchMulv2s64(*${root}, MRI); }]),
+  (apply [{ applyMulv2s64(*${root}, MRI, B, Observer); }])
 >;
 
 def build_vector_to_dup : GICombineRule<
@@ -307,7 +315,7 @@ def AArch64PostLegalizerLowering
                         icmp_lowering, build_vector_lowering,
                         lower_vector_fcmp, form_truncstore,
                         vector_sext_inreg_to_shift,
-                        unmerge_ext_to_unmerge, lower_mull,
+                        unmerge_ext_to_unmerge, lower_mulv2s64,
                         vector_unmerge_lowering, insertelt_nonconst]> {
 }
 
@@ -330,5 +338,5 @@ def AArch64PostLegalizerCombiner
                         select_to_minmax, or_to_bsp, combine_concat_vector,
                         commute_constant_to_rhs,
                         push_freeze_to_prevent_poison_from_propagating,
-                        combine_mul_cmlt, combine_use_vector_truncate]> {
+                        combine_mul_cmlt, combine_use_vector_truncate, extmultomull]> {
 }
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index 28d9f4f50f3883..20b5288e0b1945 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -438,6 +438,109 @@ void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
   MI.eraseFromParent();
 }
 
+// Match mul({z/s}ext , {z/s}ext) => {u/s}mull
+bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       GISelKnownBits *KB,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  // Get the instructions that defined the source operand
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
+  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
+  unsigned I1Opc = I1->getOpcode();
+  unsigned I2Opc = I2->getOpcode();
+
+  if (!DstTy.isVector() || I1->getNumOperands() < 2 || I2->getNumOperands() < 2)
+    return false;
+
+  auto IsAtLeastDoubleExtend = [&](Register R) {
+    LLT Ty = MRI.getType(R);
+    return DstTy.getScalarSizeInBits() >= Ty.getScalarSizeInBits() * 2;
+  };
+
+  // If the source operands were EXTENDED before, then {U/S}MULL can be used
+  bool IsZExt1 =
+      I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_ANYEXT;
+  bool IsZExt2 =
+      I2Opc == TargetOpcode::G_ZEXT || I2Opc == TargetOpcode::G_ANYEXT;
+  if (IsZExt1 && IsZExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
+      IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
+    get<0>(MatchInfo) = true;
+    get<1>(MatchInfo) = I1->getOperand(1).getReg();
+    get<2>(MatchInfo) = I2->getOperand(1).getReg();
+    return true;
+  }
+
+  bool IsSExt1 =
+      I1Opc == TargetOpcode::G_SEXT || I1Opc == TargetOpcode::G_ANYEXT;
+  bool IsSExt2 =
+      I2Opc == TargetOpcode::G_SEXT || I2Opc == TargetOpcode::G_ANYEXT;
+  if (IsSExt1 && IsSExt2 && IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
+      IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
+    get<0>(MatchInfo) = false;
+    get<1>(MatchInfo) = I1->getOperand(1).getReg();
+    get<2>(MatchInfo) = I2->getOperand(1).getReg();
+    return true;
+  }
+
+  // Select SMULL if we can replace zext with sext.
+  if (KB && ((IsSExt1 && IsZExt2) || (IsZExt1 && IsSExt2)) &&
+      IsAtLeastDoubleExtend(I1->getOperand(1).getReg()) &&
+      IsAtLeastDoubleExtend(I2->getOperand(1).getReg())) {
+    Register ZExtOp =
+        IsZExt1 ? I1->getOperand(1).getReg() : I2->getOperand(1).getReg();
+    if (KB->signBitIsZero(ZExtOp)) {
+      get<0>(MatchInfo) = false;
+      get<1>(MatchInfo) = I1->getOperand(1).getReg();
+      get<2>(MatchInfo) = I2->getOperand(1).getReg();
+      return true;
+    }
+  }
+
+  // Select UMULL if we can replace the other operand with an extend.
+  if (KB && (IsZExt1 || IsZExt2) &&
+      IsAtLeastDoubleExtend(IsZExt1 ? I1->getOperand(1).getReg()
+                                    : I2->getOperand(1).getReg())) {
+    APInt Mask = APInt::getHighBitsSet(DstTy.getScalarSizeInBits(),
+                                       DstTy.getScalarSizeInBits() / 2);
+    Register ZExtOp =
+        IsZExt1 ? MI.getOperand(2).getReg() : MI.getOperand(1).getReg();
+    if (KB->maskedValueIsZero(ZExtOp, Mask)) {
+      get<0>(MatchInfo) = true;
+      get<1>(MatchInfo) = IsZExt1 ? I1->getOperand(1).getReg() : ZExtOp;
+      get<2>(MatchInfo) = IsZExt1 ? ZExtOp : I2->getOperand(1).getReg();
+      return true;
+    }
+  }
+  return false;
+}
+
+void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
+                       MachineIRBuilder &B, GISelChangeObserver &Observer,
+                       std::tuple<bool, Register, Register> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_MUL &&
+         "Expected a G_MUL instruction");
+
+  // Get the instructions that defined the source operand
+  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+  bool IsZExt = get<0>(MatchInfo);
+  Register Src1Reg = get<1>(MatchInfo);
+  Register Src2Reg = get<2>(MatchInfo);
+  LLT Src1Ty = MRI.getType(Src1Reg);
+  LLT Src2Ty = MRI.getType(Src2Reg);
+  LLT HalfDstTy = DstTy.changeElementSize(DstTy.getScalarSizeInBits() / 2);
+  unsigned ExtOpc = IsZExt ? TargetOpcode::G_ZEXT : TargetOpcode::G_SEXT;
+
+  if (Src1Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+    Src1Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src1Reg}).getReg(0);
+  if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
+    Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);
+
+  B.setInstrAndDebugLoc(MI);
+  B.buildInstr(IsZExt ? AArch64::G_UMULL : AArch64::G_SMULL,
+               {MI.getOperand(0).getReg()}, {Src1Reg, Src2Reg});
+  MI.eraseFromParent();
+}
+
 class AArch64PostLegalizerCombinerImpl : public Combiner {
 protected:
   // TODO: Make CombinerHelper methods const.
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
index b40fe55fdfaf67..e7950826220a3a 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -1177,68 +1177,25 @@ void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
 // Doing these two matches in one function to ensure that the order of matching
 // will always be the same.
 // Try lowering MUL to MULL before trying to scalarize if needed.
-bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
+bool matchMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI) {
   // Get the instructions that defined the source operand
   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
-  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
-  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
-
-  if (DstTy.isVector()) {
-    // If the source operands were EXTENDED before, then {U/S}MULL can be used
-    unsigned I1Opc = I1->getOpcode();
-    unsigned I2Opc = I2->getOpcode();
-    if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
-         (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
-        (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
-         MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
-        (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
-         MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
-      return true;
-    }
-    // If result type is v2s64, scalarise the instruction
-    else if (DstTy == LLT::fixed_vector(2, 64)) {
-      return true;
-    }
-  }
-  return false;
+  return DstTy == LLT::fixed_vector(2, 64);
 }
 
-void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
-                       MachineIRBuilder &B, GISelChangeObserver &Observer) {
+void applyMulv2s64(MachineInstr &MI, MachineRegisterInfo &MRI,
+                   MachineIRBuilder &B, GISelChangeObserver &Observer) {
   assert(MI.getOpcode() == TargetOpcode::G_MUL &&
          "Expected a G_MUL instruction");
 
   // Get the instructions that defined the source operand
   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
-  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
-  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
-
-  // If the source operands were EXTENDED before, then {U/S}MULL can be used
-  unsigned I1Opc = I1->getOpcode();
-  unsigned I2Opc = I2->getOpcode();
-  if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
-       (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
-      (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
-       MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
-      (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
-       MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
-
-    B.setInstrAndDebugLoc(MI);
-    B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
-                                                         : AArch64::G_SMULL,
-                 {MI.getOperand(0).getReg()},
-                 {I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
-    MI.eraseFromParent();
-  }
-  // If result type is v2s64, scalarise the instruction
-  else if (DstTy == LLT::fixed_vector(2, 64)) {
-    LegalizerHelper Helper(*MI.getMF(), Observer, B);
-    B.setInstrAndDebugLoc(MI);
-    Helper.fewerElementsVector(
-        MI, 0,
-        DstTy.changeElementCount(
-            DstTy.getElementCount().divideCoefficientBy(2)));
-  }
+  assert(DstTy == LLT::fixed_vector(2, 64) && "Expected v2s64 Mul");
+  LegalizerHelper Helper(*MI.getMF(), Observer, B);
+  B.setInstrAndDebugLoc(MI);
+  Helper.fewerElementsVector(
+      MI, 0,
+      DstTy.changeElementCount(DstTy.getElementCount().divideCoefficientBy(2)));
 }
 
 class AArch64PostLegalizerLoweringImpl : public Combiner {
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
index d677526bab0005..201c05f624f214 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -82,14 +82,10 @@ define <8 x i32> @smull_zext_v8i8_v8i32(ptr %A, ptr %B) nounwind {
 ; CHECK-GI-LABEL: smull_zext_v8i8_v8i32:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    ldr d0, [x0]
-; CHECK-GI-NEXT:    ldr q1, [x1]
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    ushll v2.4s, v0.4h, #0
-; CHECK-GI-NEXT:    ushll2 v3.4s, v0.8h, #0
-; CHECK-GI-NEXT:    sshll v0.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll2 v1.4s, v1.8h, #0
-; CHECK-GI-NEXT:    mul v0.4s, v2.4s, v0.4s
-; CHECK-GI-NEXT:    mul v1.4s, v3.4s, v1.4s
+; CHECK-GI-NEXT:    ldr q2, [x1]
+; CHECK-GI-NEXT:    ushll v1.8h, v0.8b, #0
+; CHECK-GI-NEXT:    smull v0.4s, v1.4h, v2.4h
+; CHECK-GI-NEXT:    smull2 v1.4s, v1.8h, v2.8h
 ; CHECK-GI-NEXT:    ret
   %load.A = load <8 x i8>, ptr %A
   %load.B = load <8 x i16>, ptr %B
@@ -121,14 +117,10 @@ define <8 x i32> @smull_zext_v8i8_v8i32_sext_first_operand(ptr %A, ptr %B) nounw
 ; CHECK-GI-LABEL: smull_zext_v8i8_v8i32_sext_first_operand:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    ldr d0, [x1]
-; CHECK-GI-NEXT:    ldr q1, [x0]
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    sshll v2.4s, v1.4h, #0
-; CHECK-GI-NEXT:    sshll2 v1.4s, v1.8h, #0
-; CHECK-GI-NEXT:    ushll v3.4s, v0.4h, #0
-; CHECK-GI-NEXT:    ushll2 v4.4s, v0.8h, #0
-; CHECK-GI-NEXT:    mul v0.4s, v2.4s, v3.4s
-; CHECK-GI-NEXT:    mul v1.4s, v1.4s, v4.4s
+; CHECK-GI-NEXT:    ldr q2, [x0]
+; CHECK-GI-NEXT:    ushll v1.8h, v0.8b, #0
+; CHECK-GI-NEXT:    smull v0.4s, v2.4h, v1.4h
+; CHECK-GI-NEXT:    smull2 v1.4s, v2.8h, v1.8h
 ; CHECK-GI-NEXT:    ret
   %load.A = load <8 x i16>, ptr %A
   %load.B = load <8 x i8>, ptr %B
@@ -318,16 +310,7 @@ define <2 x i64> @smull_zext_and_v2i32_v2i64(ptr %A, ptr %B) nounwind {
 ; CHECK-GI-NEXT:    ldr d1, [x0]
 ; CHECK-GI-NEXT:    and v0.8b, v1.8b, v0.8b
 ; CHECK-GI-NEXT:    ldr d1, [x1]
-; CHECK-GI-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-GI-NEXT:    ushll v0.2d, v0.2s, #0
-; CHECK-GI-NEXT:    fmov x9, d1
-; CHECK-GI-NEXT:    mov x11, v1.d[1]
-; CHECK-GI-NEXT:    fmov x8, d0
-; CHECK-GI-NEXT:    mov x10, v0.d[1]
-; CHECK-GI-NEXT:    mul x8, x8, x9
-; CHECK-GI-NEXT:    mul x9, x10, x11
-; CHECK-GI-NEXT:    mov v0.d[0], x8
-; CHECK-GI-NEXT:    mov v0.d[1], x9
+; CHECK-GI-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-GI-NEXT:    ret
   %load.A = load <2 x i32>, ptr %A
   %and.A = and <2 x i32> %load.A, <i32 u0x7FFFFFFF, i32 u0x7FFFFFFF>
@@ -1076,8 +1059,8 @@ define <8 x i16> @umull_extvec_v8i8_v8i16(<8 x i8> %arg) nounwind {
 ; CHECK-GI-LABEL: umull_extvec_v8i8_v8i16:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    movi v1.8h, #12
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-GI-NEXT:    xtn v1.8b, v1.8h
+; CHECK-GI-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-GI-NEXT:    ret
   %tmp3 = zext <8 x i8> %arg to <8 x i16>
   %tmp4 = mul <8 x i16> %tmp3, <i16 12, i16 12, i16 12, i16 12, i16 12, i16 12, i16 12, i16 12>
@@ -1132,9 +1115,9 @@ define <4 x i32> @umull_extvec_v4i16_v4i32(<4 x i16> %arg) nounwind {
 ; CHECK-GI-LABEL: umull_extvec_v4i16_v4i32:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    adrp x8, .LCPI39_0
-; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
 ; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI39_0]
-; CHECK-GI-NEXT:    mul v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    xtn v1.4h, v1.4s
+; CHECK-GI-NEXT:    umull v0.4s, v0.4h, v1.4h
 ; CHECK-GI-NEXT:    ret
   %tmp3 = zext <4 x i16> %arg to <4 x i32>
   %tmp4 = mul <4 x i32> %tmp3, <i32 1234, i32 1234, i32 1234, i32 1234>
@@ -1159,16 +1142,9 @@ define <2 x i64> @umull_extvec_v2i32_v2i64(<2 x i32> %arg) nounwind {
 ; CHECK-GI-LABEL: umull_extvec_v2i32_v2i64:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    adrp x8, .LCPI40_0
-; CHECK-GI-NEXT:    ushll v0.2d, v0.2s, #0
 ; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI40_0]
-; CHECK-GI-NEXT:    fmov x8, d0
-; CHECK-GI-NEXT:    fmov x9, d1
-; CHECK-GI-NEXT:    mov x10, v0.d[1]
-; CHECK-GI-NEXT:    mov x11, v1.d[1]
-; CHECK-GI-NEXT:    mul x8, x8, x9
-; CHECK-GI-NEXT:    mul x9, x10, x11
-; CHECK-GI-NEXT:    mov v0.d[0], x8
-; CHECK-GI-NEXT:    mov v0.d[1], x9
+; CHECK-GI-NEXT:    xtn v1.2s, v1.2d
+; CHECK-GI-NEXT:    umull v0.2d, v0.2s, v1.2s
 ; CHECK-GI-NEXT:    ret
   %tmp3 = zext <2 x i32> %arg to <2 x i64>
   %tmp4 = mul <2 x i64> %tmp3, <i64 1234, i64 1234>
@@ -1193,9 +1169,9 @@ define <8 x i16> @amull_extvec_v8i8_v8i16(<8 x i8> %arg) nounwind {
 ; CHECK-GI-LABEL: amull_extvec_v8i8_v8i16:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    movi v1.8h, #12
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
 ; CHECK-GI-NEXT:    movi v2.2d, #0xff00ff00ff00ff
-; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-GI-NEXT:    xtn v1.8b, v1.8h
+; CHECK-GI-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-GI-NEXT:    and v0.16b, v0.16b, v2.16b
 ; CHECK-GI-NEXT:    ret
   %tmp3 = zext <8 x i8> %arg to <8 x i16>
@@ -1226,10 +1202,10 @@ define <4 x i32> @amull_extvec_v4i16_v4i32(<4 x i16> %arg) nounwind {
 ; CHECK-GI-LABEL: amull_extvec_v4i16_v4i32:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    adrp x8, .LCPI42_0
-; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
 ; CHECK-GI-NEXT:    movi v2.2d, #0x00ffff0000ffff
 ; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI42_0]
-; CHECK-GI-NEXT:    mul v0.4s, v0.4s, v1.4s
+; CHECK-GI-NEXT:    xtn v1.4h, v1.4s
+; CHECK-GI-NEXT:    umull v0.4s, v0.4h, v1.4h
 ; CHECK-GI-NEXT:    and v0.16b, v0.16b, v2.16b
 ; CHECK-GI-NEXT:    ret
   %tmp3 = zext <4 x i16> %arg to <4 x i32>
@@ -1260,18 +1236,11 @@ define <2 x i64> @amull_extvec_v2i32_v2i64(<2 x i32> %arg) nounwind {
 ; CHECK-GI-LABEL: amull_extvec_v2i32_v2i64:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    adrp x8, .LCPI43_0
-; CHECK-GI-NEXT:    ushll v0.2d, v0.2s, #0
+; CHECK-GI-NEXT:    movi v2.2d, #0x000000ffffffff
 ; CHECK-GI-NEXT:    ldr q1, [x8, :lo12:.LCPI43_0]
-; CHECK-GI-NEXT:    fmov x8, d0
-; CHECK-GI-NEXT:    fmov x9, d1
-; CHECK-GI-NEXT:    mov x10, v0.d[1]
-; CHECK-GI-NEXT:    mov x11, v1.d[1]
-; CHECK-GI-NEXT:    movi v1.2d, #0x000000ffffffff
-; CHECK-GI-NEXT:    mul x8, x8, x9
-; CHECK-GI-NEXT:    mul x9, x10, x11
-; CHECK-GI-NEXT:    mov v0.d[0], x8
-; CHECK-GI-NEXT:    mov v0.d[1], x9
-; CHECK-GI-NEXT:    and v0.16b, v0.16b, v1.16b
+; CHECK-GI-NEXT:    xtn v1.2s, v1.2d
+; CHECK-GI-NEXT:    umull v0.2d, v0.2s, v1.2s
+; CHECK-GI-NEXT:    and v0.16b, v0.16b, v2.16b
 ; CHECK-GI-NEXT:    ret
   %tmp3 = zext <2 x i32> %arg to <2 x i64>
   %tmp4 = mul <2 x i64> %tmp3, <i64 1234, i64 1234>
@@ -1649,9 +1618,9 @@ define <8 x i16> @umull_and_v8i16(<8 x i8> %src1, <8 x i16> %src2) {
 ; CHECK-GI-LABEL: umull_and_v8i16:
 ; CHECK-GI:       // %bb.0: // %entry
 ; CHECK-GI-NEXT:    movi v2.2d, #0xff00ff00ff00ff
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
 ; CHECK-GI-NEXT:    and v1.16b, v1.16b, v2.16b
-; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-GI-NEXT:    xtn v1.8b, v1.8h
+; CHECK-GI-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-GI-NEXT:    ret
 entry:
   %in1 = zext <8 x i8> %src1 to <8 x i16>
@@ -1678,9 +1647,9 @@ define <8 x i16> @umull_and_v8i16_c(<8 x i8> %src1, <8 x i16> %src2) {
 ; CHECK-GI-LABEL: umull_and_v8i16_c:
 ; CHECK-GI:       // %bb.0: // %entry
 ; CHECK-GI-NEXT:    movi v2.2d, #0xff00ff00ff00ff
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
 ; CHECK-GI-NEXT:    and v1.16b, v1.16b, v2.16b
-; CHECK-GI-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-GI-NEXT:    xtn v1.8b, v1.8h
+; CHECK-GI-NEXT:    umull v0.8h, v1.8b, v0.8b
 ; CHECK-GI-NEXT:    ret
 entry:
   %in1 = zext <8 x i8> %src1 to <8 x i16>
@@ -1720,8 +1689,8 @@ define <8 x i16> @umull_andconst_v8i16(<8 x i8> %src1, <8 x i16> %src2) {
 ; CHECK-GI-LABEL: umull_andconst_v8i16:
 ; CHECK-GI:       // %bb.0: // %entry
 ; CHECK-GI-NEXT:    movi v1.2d, #0xff00ff00ff00ff
-; CHECK-GI-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-GI-NEXT:    xtn v1.8b, v1.8h
+; CHECK-GI-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-GI-NEXT:    ret
 entry:
   %in1 = zext <8 x i8> %src1 to <8 x i16>
@@ -1765,29 +1734,13 @@ entry:
 }
 
 define <4 x i32> @umull_and_v4i32(<4 x i16> %src1, <4 x i32> %src2) {
-; CHECK-NEON-LABEL: umull_and_v4i32:
-; CHECK-NEON:       // %bb.0: // %entry
-; CHECK-NEON-NEXT:    movi v2.2d, #0x0000ff000000ff
-; CHECK-NEON-NEXT:    and v1.16b, v1.16b, v2.16b
-; CHECK-NEON-NEXT:    xtn v1.4h, v1.4s
-; CHECK-NEON-NEXT:    umull v0.4s, v0.4h, v1.4h
-; CHECK-NEON-NEXT:    ret
-;
-; CHECK-SVE-LABEL: umull_and_v4i32:
-; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    movi v2.2d, #0x0000ff000000ff
-; CHECK-SVE-NEXT:    and v1.16b, v1.16b, v2.16b
-; CHECK-SVE-NEXT:    xtn v1.4h, v1.4s
-; CHECK-SVE-NEXT:    umull v0.4s, v0.4h, v1.4h
-; CHECK-SVE-NEXT:    ret
-;
-; CHECK-GI-LABEL: umull_and_v4i32:
-; CHECK-GI:       // %bb.0: // %entry
-; CHECK-GI-NEXT:    movi v2.2d, #0x0000ff000000ff
-; CHECK-GI-NEXT:    ushll v0.4s, v0.4h, #0
-; CHECK-GI-NEXT:    and v1.16b, v1.16b, v2.16b
-; CHECK-GI-NEXT:    mul v0.4s, v0.4s, v1.4s
-; CHECK-GI-NEXT...
[truncated]

def mull_matchdata : GIDefMatchData<"std::tuple<bool, Register, Register>">;
def extmultomull : GICombineRule<
(defs root:$root, mull_matchdata:$matchinfo),
(match (wip_match_opcode G_MUL):$root,
Contributor

Can we please avoid using wip_match_opcode? I tried to get rid of it from this file for most things, so introducing more would be a step backwards :)

Collaborator Author

In this case we only want to check the opcode, as the operand info is passed between the match and the apply through a matchinfo. I'm not sure I understand why we want to have to specify the operands if they are never used, but I've switched it over.

def extmultomull : GICombineRule<
(defs root:$root, mull_matchdata:$matchinfo),
(match (G_MUL $dst, $src1, $src2):$root,
[{ return matchExtMulToMULL(*${root}, MRI, KB, ${matchinfo}); }]),


Could you take a look at push_sub_through_zext, which is at the top of this file?

@davemgreen
Collaborator Author

I've updated this to be more like the SDAG equivalent in selectUmullSmull, where it looks at the known bits / sign bits for converting to umull/smull, especially for 64-bit types.
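
As a sketch of the known-bits case (this mirrors the smull_zext_and_v2i32_v2i64 test updated by the patch; the function name here is illustrative):

define <2 x i64> @smull_zext_and(<2 x i32> %a, <2 x i32> %b) {
  ; The and clears the sign bit of %a, so its zext is equivalent to a sext,
  ; and the v2i64 mul can be selected as smull v0.2d, v0.2s, v1.2s instead
  ; of being scalarized.
  %masked = and <2 x i32> %a, <i32 u0x7FFFFFFF, i32 u0x7FFFFFFF>
  %za = zext <2 x i32> %masked to <2 x i64>
  %sb = sext <2 x i32> %b to <2 x i64>
  %m = mul <2 x i64> %za, %sb
  ret <2 x i64> %m
}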

Collaborator Author

@davemgreen left a comment


Rebase and ping - thanks.

if (Src2Ty.getScalarSizeInBits() * 2 != DstTy.getScalarSizeInBits())
Src2Reg = B.buildExtOrTrunc(ExtOpc, {HalfDstTy}, {Src2Reg}).getReg(0);

B.setInstrAndDebugLoc(MI);
Contributor

I think this is pre-set before any apply action now.

This splits the existing post-legalize lowering of vector umull/smull into two
parts - one to perform the optimization of mul(ext,ext) -> mull and one to
perform the v2i64 mul scalarization. The mull part is moved to post legalizer
combine and has been taught a few extra tricks from SDAG, using known bits to
convert mul(sext, zext) or mul(zext, zero-upper-bits) into umull. This can be
important to prevent v2i64 scalarization of muls.
@davemgreen
Collaborator Author

I've rebased over the changes in GISelValueTracking. Thanks.

@davemgreen merged commit b4017d8 into llvm:main Apr 15, 2025
11 checks passed
@davemgreen deleted the gh-gi-mullus branch April 15, 2025 09:10
@kazutakahirata
Contributor

@davemgreen I've landed e65faed to fix a warning from this PR. Thanks!
