From cc9dfdcbc1f753656ca91bbfa077fceafdd1304e Mon Sep 17 00:00:00 2001 From: David Green Date: Wed, 20 Nov 2024 09:10:29 +0000 Subject: [PATCH] [AArch64] Improve mull generation This attempts to clean up and improve where we generate smull using known-bits. For v2i64 types (where no mul is present), we try to create mull more aggressively to avoid scalarization. --- .../Target/AArch64/AArch64ISelLowering.cpp | 103 ++++-------------- llvm/test/CodeGen/AArch64/aarch64-smull.ll | 92 +++++----------- 2 files changed, 51 insertions(+), 144 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index ad1d1237aa25a..2ec0e0bb7dff7 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -5186,40 +5186,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op); } -static EVT getExtensionTo64Bits(const EVT &OrigVT) { - if (OrigVT.getSizeInBits() >= 64) - return OrigVT; - - assert(OrigVT.isSimple() && "Expecting a simple value type"); - - MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy; - switch (OrigSimpleTy) { - default: llvm_unreachable("Unexpected Vector Type"); - case MVT::v2i8: - case MVT::v2i16: - return MVT::v2i32; - case MVT::v4i8: - return MVT::v4i16; - } -} - -static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG, - const EVT &OrigTy, - const EVT &ExtTy, - unsigned ExtOpcode) { - // The vector originally had a size of OrigTy. It was then extended to ExtTy. - // We expect the ExtTy to be 128-bits total. If the OrigTy is less than - // 64-bits we need to insert a new extension so that it will be 64-bits. - assert(ExtTy.is128BitVector() && "Unexpected extension size"); - if (OrigTy.getSizeInBits() >= 64) - return N; - - // Must extend size to at least 64 bits to be used as an operand for VMULL. 
- EVT NewVT = getExtensionTo64Bits(OrigTy); - - return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N); -} - // Returns lane if Op extracts from a two-element vector and lane is constant // (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise. static std::optional @@ -5265,31 +5231,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG, static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) { EVT VT = N.getValueType(); assert(VT.is128BitVector() && "Unexpected vector MULL size"); - - unsigned NumElts = VT.getVectorNumElements(); - unsigned OrigEltSize = VT.getScalarSizeInBits(); - unsigned EltSize = OrigEltSize / 2; - MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts); - - APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize); - if (DAG.MaskedValueIsZero(N, HiBits)) - return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N); - - if (ISD::isExtOpcode(N.getOpcode())) - return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG, - N.getOperand(0).getValueType(), VT, - N.getOpcode()); - - assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR"); - SDLoc dl(N); - SmallVector Ops; - for (unsigned i = 0; i != NumElts; ++i) { - const APInt &CInt = N.getConstantOperandAPInt(i); - // Element types smaller than 32 bits are not legal, so use i32 elements. - // The values are implicitly truncated so sext vs. zext doesn't matter. 
- Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32)); - } - return DAG.getBuildVector(TruncVT, dl, Ops); + EVT HalfVT = EVT::getVectorVT( + *DAG.getContext(), + VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()), + VT.getVectorElementCount()); + return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N); } static bool isSignExtended(SDValue N, SelectionDAG &DAG) { @@ -5465,33 +5411,26 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG, if (IsN0ZExt && IsN1ZExt) return AArch64ISD::UMULL; - // Select SMULL if we can replace zext with sext. - if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) && - !isExtendedBUILD_VECTOR(N0, DAG, false) && - !isExtendedBUILD_VECTOR(N1, DAG, false)) { - SDValue ZextOperand; - if (IsN0ZExt) - ZextOperand = N0.getOperand(0); - else - ZextOperand = N1.getOperand(0); - if (DAG.SignBitIsZero(ZextOperand)) { - SDValue NewSext = - DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType()); - if (IsN0ZExt) - N0 = NewSext; - else - N1 = NewSext; - return AArch64ISD::SMULL; - } - } - // Select UMULL if we can replace the other operand with an extend. + EVT VT = N0.getValueType(); + unsigned EltSize = VT.getScalarSizeInBits(); + APInt Mask = APInt::getHighBitsSet(EltSize, EltSize / 2); if (IsN0ZExt || IsN1ZExt) { - EVT VT = N0.getValueType(); - APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(), - VT.getScalarSizeInBits() / 2); if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask)) return AArch64ISD::UMULL; + } else if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) && + DAG.MaskedValueIsZero(N1, Mask)) { + // For v2i64 we look more aggressively at both operands being zero, to avoid + // scalarization. + return AArch64ISD::UMULL; + } + + if (IsN0SExt || IsN1SExt) { + if (DAG.ComputeNumSignBits(IsN0SExt ? 
N1 : N0) > EltSize / 2) + return AArch64ISD::SMULL; + } else if (VT == MVT::v2i64 && DAG.ComputeNumSignBits(N0) > EltSize / 2 && + DAG.ComputeNumSignBits(N1) > EltSize / 2) { + return AArch64ISD::SMULL; } if (!IsN1SExt && !IsN1ZExt) diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll index 69a69dbd3b18b..0fe2bbe2c449f 100644 --- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll +++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll @@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind { define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind { ; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64: ; CHECK-NEON: // %bb.0: -; CHECK-NEON-NEXT: ldr d0, [x1] -; CHECK-NEON-NEXT: ldrh w9, [x0] -; CHECK-NEON-NEXT: ldrh w10, [x0, #2] -; CHECK-NEON-NEXT: sshll v0.2d, v0.2s, #0 -; CHECK-NEON-NEXT: fmov x11, d0 -; CHECK-NEON-NEXT: mov x8, v0.d[1] -; CHECK-NEON-NEXT: smull x9, w9, w11 -; CHECK-NEON-NEXT: smull x8, w10, w8 -; CHECK-NEON-NEXT: fmov d0, x9 -; CHECK-NEON-NEXT: mov v0.d[1], x8 +; CHECK-NEON-NEXT: ldrh w8, [x0] +; CHECK-NEON-NEXT: ldrh w9, [x0, #2] +; CHECK-NEON-NEXT: ldr d1, [x1] +; CHECK-NEON-NEXT: fmov d0, x8 +; CHECK-NEON-NEXT: mov v0.d[1], x9 +; CHECK-NEON-NEXT: xtn v0.2s, v0.2d +; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-NEON-NEXT: ret ; ; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64: ; CHECK-SVE: // %bb.0: ; CHECK-SVE-NEXT: ldrh w8, [x0] ; CHECK-SVE-NEXT: ldrh w9, [x0, #2] -; CHECK-SVE-NEXT: ptrue p0.d, vl2 -; CHECK-SVE-NEXT: ldr d0, [x1] -; CHECK-SVE-NEXT: fmov d1, x8 -; CHECK-SVE-NEXT: sshll v0.2d, v0.2s, #0 -; CHECK-SVE-NEXT: mov v1.d[1], x9 -; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d -; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0 +; CHECK-SVE-NEXT: ldr d1, [x1] +; CHECK-SVE-NEXT: fmov d0, x8 +; CHECK-SVE-NEXT: mov v0.d[1], x9 +; CHECK-SVE-NEXT: xtn v0.2s, v0.2d +; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-SVE-NEXT: ret ; ; CHECK-GI-LABEL: 
smull_zext_v2i32_v2i64: @@ -2404,25 +2399,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) { define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) { ; CHECK-NEON-LABEL: lsr: ; CHECK-NEON: // %bb.0: -; CHECK-NEON-NEXT: ushr v0.2d, v0.2d, #32 -; CHECK-NEON-NEXT: ushr v1.2d, v1.2d, #32 -; CHECK-NEON-NEXT: fmov x10, d1 -; CHECK-NEON-NEXT: fmov x11, d0 -; CHECK-NEON-NEXT: mov x8, v1.d[1] -; CHECK-NEON-NEXT: mov x9, v0.d[1] -; CHECK-NEON-NEXT: umull x10, w11, w10 -; CHECK-NEON-NEXT: umull x8, w9, w8 -; CHECK-NEON-NEXT: fmov d0, x10 -; CHECK-NEON-NEXT: mov v0.d[1], x8 +; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32 +; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32 +; CHECK-NEON-NEXT: umull v0.2d, v0.2s, v1.2s ; CHECK-NEON-NEXT: ret ; ; CHECK-SVE-LABEL: lsr: ; CHECK-SVE: // %bb.0: -; CHECK-SVE-NEXT: ushr v0.2d, v0.2d, #32 -; CHECK-SVE-NEXT: ushr v1.2d, v1.2d, #32 -; CHECK-SVE-NEXT: ptrue p0.d, vl2 -; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d -; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0 +; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32 +; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32 +; CHECK-SVE-NEXT: umull v0.2d, v0.2s, v1.2s ; CHECK-SVE-NEXT: ret ; ; CHECK-GI-LABEL: lsr: @@ -2481,25 +2467,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) { define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) { ; CHECK-NEON-LABEL: asr: ; CHECK-NEON: // %bb.0: -; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32 -; CHECK-NEON-NEXT: sshr v1.2d, v1.2d, #32 -; CHECK-NEON-NEXT: fmov x10, d1 -; CHECK-NEON-NEXT: fmov x11, d0 -; CHECK-NEON-NEXT: mov x8, v1.d[1] -; CHECK-NEON-NEXT: mov x9, v0.d[1] -; CHECK-NEON-NEXT: smull x10, w11, w10 -; CHECK-NEON-NEXT: smull x8, w9, w8 -; CHECK-NEON-NEXT: fmov d0, x10 -; CHECK-NEON-NEXT: mov v0.d[1], x8 +; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32 +; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32 +; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-NEON-NEXT: ret ; ; CHECK-SVE-LABEL: asr: ; CHECK-SVE: // %bb.0: -; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32 -; CHECK-SVE-NEXT: 
sshr v1.2d, v1.2d, #32 -; CHECK-SVE-NEXT: ptrue p0.d, vl2 -; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d -; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0 +; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32 +; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32 +; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-SVE-NEXT: ret ; ; CHECK-GI-LABEL: asr: @@ -2524,25 +2501,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) { define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) { ; CHECK-NEON-LABEL: asr_const: ; CHECK-NEON: // %bb.0: -; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32 -; CHECK-NEON-NEXT: fmov x9, d0 -; CHECK-NEON-NEXT: mov x8, v0.d[1] -; CHECK-NEON-NEXT: lsl x10, x9, #5 -; CHECK-NEON-NEXT: lsl x11, x8, #5 -; CHECK-NEON-NEXT: sub x9, x10, x9 -; CHECK-NEON-NEXT: fmov d0, x9 -; CHECK-NEON-NEXT: sub x8, x11, x8 -; CHECK-NEON-NEXT: mov v0.d[1], x8 +; CHECK-NEON-NEXT: movi v1.2s, #31 +; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32 +; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-NEON-NEXT: ret ; ; CHECK-SVE-LABEL: asr_const: ; CHECK-SVE: // %bb.0: -; CHECK-SVE-NEXT: mov w8, #31 // =0x1f -; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32 -; CHECK-SVE-NEXT: ptrue p0.d, vl2 -; CHECK-SVE-NEXT: dup v1.2d, x8 -; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d -; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0 +; CHECK-SVE-NEXT: movi v1.2s, #31 +; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32 +; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s ; CHECK-SVE-NEXT: ret ; ; CHECK-GI-LABEL: asr_const: