Conversation

@davemgreen davemgreen commented Nov 5, 2024

This attempts to clean up and improve the places where we generate smull/umull using known bits. For v2i64 types (where no vector mul instruction is available), we try to create mull more aggressively to avoid scalarization.
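
As an illustration of the v2i64 case, here is a small hypothetical IR function in the spirit of the lsr test updated below (the test's IR is not shown in the hunks, so this is a sketch reconstructed from its CHECK lines). After the shifts, the top 32 bits of every element are known zero, so the mul can now be lowered to umull v0.2d, v0.2s, v1.2s instead of being scalarized:

define <2 x i64> @shifted_mul(<2 x i64> %a, <2 x i64> %b) {
  ; After the logical shifts, the high 32 bits of each element are known zero,
  ; so a 32x32->64 unsigned widening multiply (umull) is sufficient.
  %as = lshr <2 x i64> %a, <i64 32, i64 32>
  %bs = lshr <2 x i64> %b, <i64 32, i64 32>
  %m = mul <2 x i64> %as, %bs
  ret <2 x i64> %m
}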

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-backend-aarch64

Author: David Green (davemgreen)

Changes

This attempts to clean up and improve where we generate smull using known bits. For v2i64 types (where no vector mul instruction is available), we try to create mull more aggressively to avoid scalarization.


Full diff: https://github.com/llvm/llvm-project/pull/114997.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+23-84)
  • (modified) llvm/test/CodeGen/AArch64/aarch64-smull.ll (+30-62)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0814380b188485..071c13853eafdf 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5171,40 +5171,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
   return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
 }
 
-static EVT getExtensionTo64Bits(const EVT &OrigVT) {
-  if (OrigVT.getSizeInBits() >= 64)
-    return OrigVT;
-
-  assert(OrigVT.isSimple() && "Expecting a simple value type");
-
-  MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
-  switch (OrigSimpleTy) {
-  default: llvm_unreachable("Unexpected Vector Type");
-  case MVT::v2i8:
-  case MVT::v2i16:
-     return MVT::v2i32;
-  case MVT::v4i8:
-    return  MVT::v4i16;
-  }
-}
-
-static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
-                                                 const EVT &OrigTy,
-                                                 const EVT &ExtTy,
-                                                 unsigned ExtOpcode) {
-  // The vector originally had a size of OrigTy. It was then extended to ExtTy.
-  // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
-  // 64-bits we need to insert a new extension so that it will be 64-bits.
-  assert(ExtTy.is128BitVector() && "Unexpected extension size");
-  if (OrigTy.getSizeInBits() >= 64)
-    return N;
-
-  // Must extend size to at least 64 bits to be used as an operand for VMULL.
-  EVT NewVT = getExtensionTo64Bits(OrigTy);
-
-  return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
-}
-
 // Returns lane if Op extracts from a two-element vector and lane is constant
 // (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
 static std::optional<uint64_t>
@@ -5250,31 +5216,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
 static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
   EVT VT = N.getValueType();
   assert(VT.is128BitVector() && "Unexpected vector MULL size");
-
-  unsigned NumElts = VT.getVectorNumElements();
-  unsigned OrigEltSize = VT.getScalarSizeInBits();
-  unsigned EltSize = OrigEltSize / 2;
-  MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
-
-  APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
-  if (DAG.MaskedValueIsZero(N, HiBits))
-    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
-
-  if (ISD::isExtOpcode(N.getOpcode()))
-    return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
-                                             N.getOperand(0).getValueType(), VT,
-                                             N.getOpcode());
-
-  assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
-  SDLoc dl(N);
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    const APInt &CInt = N.getConstantOperandAPInt(i);
-    // Element types smaller than 32 bits are not legal, so use i32 elements.
-    // The values are implicitly truncated so sext vs. zext doesn't matter.
-    Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
-  }
-  return DAG.getBuildVector(TruncVT, dl, Ops);
+  EVT HalfVT = EVT::getVectorVT(
+      *DAG.getContext(),
+      VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()),
+      VT.getVectorElementCount());
+  return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N);
 }
 
 static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -5450,34 +5396,27 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
   if (IsN0ZExt && IsN1ZExt)
     return AArch64ISD::UMULL;
 
-  // Select SMULL if we can replace zext with sext.
-  if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
-      !isExtendedBUILD_VECTOR(N0, DAG, false) &&
-      !isExtendedBUILD_VECTOR(N1, DAG, false)) {
-    SDValue ZextOperand;
-    if (IsN0ZExt)
-      ZextOperand = N0.getOperand(0);
-    else
-      ZextOperand = N1.getOperand(0);
-    if (DAG.SignBitIsZero(ZextOperand)) {
-      SDValue NewSext =
-          DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
-      if (IsN0ZExt)
-        N0 = NewSext;
-      else
-        N1 = NewSext;
-      return AArch64ISD::SMULL;
-    }
-  }
-
   // Select UMULL if we can replace the other operand with an extend.
-  if (IsN0ZExt || IsN1ZExt) {
-    EVT VT = N0.getValueType();
-    APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
-                                       VT.getScalarSizeInBits() / 2);
+  EVT VT = N0.getValueType();
+  APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
+                                     VT.getScalarSizeInBits() / 2);
+  if (IsN0ZExt || IsN1ZExt)
     if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
       return AArch64ISD::UMULL;
-  }
+  // For v2i64 we look more aggresively at both operands being zero, to avoid
+  // scalarization.
+  if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&
+      DAG.MaskedValueIsZero(N1, Mask))
+    return AArch64ISD::UMULL;
+
+  if (IsN0SExt || IsN1SExt)
+    if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) >
+        VT.getScalarSizeInBits() / 2)
+      return AArch64ISD::SMULL;
+  if (VT == MVT::v2i64 &&
+      DAG.ComputeNumSignBits(N0) > VT.getScalarSizeInBits() / 2 &&
+      DAG.ComputeNumSignBits(N1) > VT.getScalarSizeInBits() / 2)
+    return AArch64ISD::SMULL;
 
   if (!IsN1SExt && !IsN1ZExt)
     return 0;
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
index 3c4901ade972ec..0a38fd1488e2eb 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind {
 define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind {
 ; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    ldr d0, [x1]
-; CHECK-NEON-NEXT:    ldrh w9, [x0]
-; CHECK-NEON-NEXT:    ldrh w10, [x0, #2]
-; CHECK-NEON-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEON-NEXT:    fmov x11, d0
-; CHECK-NEON-NEXT:    mov x8, v0.d[1]
-; CHECK-NEON-NEXT:    smull x9, w9, w11
-; CHECK-NEON-NEXT:    smull x8, w10, w8
-; CHECK-NEON-NEXT:    fmov d0, x9
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    ldrh w8, [x0]
+; CHECK-NEON-NEXT:    ldrh w9, [x0, #2]
+; CHECK-NEON-NEXT:    ldr d1, [x1]
+; CHECK-NEON-NEXT:    fmov d0, x8
+; CHECK-NEON-NEXT:    mov v0.d[1], x9
+; CHECK-NEON-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64:
 ; CHECK-SVE:       // %bb.0:
 ; CHECK-SVE-NEXT:    ldrh w8, [x0]
 ; CHECK-SVE-NEXT:    ldrh w9, [x0, #2]
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    ldr d0, [x1]
-; CHECK-SVE-NEXT:    fmov d1, x8
-; CHECK-SVE-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-SVE-NEXT:    mov v1.d[1], x9
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    ldr d1, [x1]
+; CHECK-SVE-NEXT:    fmov d0, x8
+; CHECK-SVE-NEXT:    mov v0.d[1], x9
+; CHECK-SVE-NEXT:    xtn v0.2s, v0.2d
+; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: smull_zext_v2i32_v2i64:
@@ -2405,25 +2400,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
 define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-NEON-LABEL: lsr:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    ushr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT:    ushr v1.2d, v1.2d, #32
-; CHECK-NEON-NEXT:    fmov x10, d1
-; CHECK-NEON-NEXT:    fmov x11, d0
-; CHECK-NEON-NEXT:    mov x8, v1.d[1]
-; CHECK-NEON-NEXT:    mov x9, v0.d[1]
-; CHECK-NEON-NEXT:    umull x10, w11, w10
-; CHECK-NEON-NEXT:    umull x8, w9, w8
-; CHECK-NEON-NEXT:    fmov d0, x10
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-NEON-NEXT:    umull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: lsr:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    ushr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT:    ushr v1.2d, v1.2d, #32
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-SVE-NEXT:    umull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: lsr:
@@ -2482,25 +2468,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) {
 define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-NEON-LABEL: asr:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT:    sshr v1.2d, v1.2d, #32
-; CHECK-NEON-NEXT:    fmov x10, d1
-; CHECK-NEON-NEXT:    fmov x11, d0
-; CHECK-NEON-NEXT:    mov x8, v1.d[1]
-; CHECK-NEON-NEXT:    mov x9, v0.d[1]
-; CHECK-NEON-NEXT:    smull x10, w11, w10
-; CHECK-NEON-NEXT:    smull x8, w9, w8
-; CHECK-NEON-NEXT:    fmov d0, x10
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: asr:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT:    sshr v1.2d, v1.2d, #32
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT:    shrn v1.2s, v1.2d, #32
+; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: asr:
@@ -2525,25 +2502,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
 define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-NEON-LABEL: asr_const:
 ; CHECK-NEON:       // %bb.0:
-; CHECK-NEON-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-NEON-NEXT:    fmov x9, d0
-; CHECK-NEON-NEXT:    mov x8, v0.d[1]
-; CHECK-NEON-NEXT:    lsl x10, x9, #5
-; CHECK-NEON-NEXT:    lsl x11, x8, #5
-; CHECK-NEON-NEXT:    sub x9, x10, x9
-; CHECK-NEON-NEXT:    fmov d0, x9
-; CHECK-NEON-NEXT:    sub x8, x11, x8
-; CHECK-NEON-NEXT:    mov v0.d[1], x8
+; CHECK-NEON-NEXT:    movi v1.2s, #31
+; CHECK-NEON-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-NEON-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEON-NEXT:    ret
 ;
 ; CHECK-SVE-LABEL: asr_const:
 ; CHECK-SVE:       // %bb.0:
-; CHECK-SVE-NEXT:    mov w8, #31 // =0x1f
-; CHECK-SVE-NEXT:    sshr v0.2d, v0.2d, #32
-; CHECK-SVE-NEXT:    ptrue p0.d, vl2
-; CHECK-SVE-NEXT:    dup v1.2d, x8
-; CHECK-SVE-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    movi v1.2s, #31
+; CHECK-SVE-NEXT:    shrn v0.2s, v0.2d, #32
+; CHECK-SVE-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-GI-LABEL: asr_const:

; CHECK-NEON-NEXT: mov v0.d[1], x8
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
; CHECK-NEON-NEXT: umull v0.2d, v0.2s, v1.2s
Contributor

Do you also plan to change the vectoriser cost model to reflect the lower cost? Looks like a fairly simple pattern that might encourage more vectorisation.

Collaborator Author

No plans at the moment, but this already appears to come up as an improvement. Trying to use known bits from the cost model would not be trivial and would need to consider the possible compile-time impact. The exact tests were just artificial examples to show it working; shifts are often the simplest way of generating known bits.
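
As a hedged illustration of that last point (using a hypothetical input, not one of the actual tests): an arithmetic shift is enough to establish the sign bits that the new ComputeNumSignBits check looks for, so a v2i64 multiply of two ashr results now selects smull rather than scalarizing.

define <2 x i64> @signed_shifted_mul(<2 x i64> %a, <2 x i64> %b) {
  ; Each element has at least 33 known sign bits after the ashr,
  ; so a 32x32->64 signed widening multiply (smull) covers the full result.
  %as = ashr <2 x i64> %a, <i64 32, i64 32>
  %bs = ashr <2 x i64> %b, <i64 32, i64 32>
  %m = mul <2 x i64> %as, %bs
  ret <2 x i64> %m
}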

@davemgreen
Collaborator Author

Rebase and ping. Thanks.


@david-arm david-arm left a comment

LGTM! I have a few suggestions for possible minor improvements, but otherwise looks good.

}
// For v2i64 we look more aggresively at both operands being zero, to avoid
// scalarization.
if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&

Is it worth making this an else if? If we entered the if (IsN0ZExt || IsN1ZExt) block and still got here, we already know that DAG.MaskedValueIsZero returned false for one of N0 or N1.


if (IsN0SExt || IsN1SExt)
if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) >
VT.getScalarSizeInBits() / 2)

nit: Is it worth pulling VT.getScalarSizeInBits() into its own variable, given it's used 3 times?

if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) >
VT.getScalarSizeInBits() / 2)
return AArch64ISD::SMULL;
if (VT == MVT::v2i64 &&

Similar to the umull case above, I think we can make this an else if.

This attempts to clean up and improve where we generate smull using known-bits.
For v2i64 types (where no mul is present), we try to create mull more
aggressively to avoid scalarization.
@davemgreen davemgreen merged commit bca846d into llvm:main Nov 20, 2024
5 of 8 checks passed
@davemgreen davemgreen deleted the gh-a64-mullknownbits branch November 20, 2024 09:12