Skip to content

Commit 548d8fe

Browse files
committed
[AArch64] Improve mull generation
This attempts to clean up and improve where we generate umull/smull using known-bits and sign-bit information. For v2i64 types (for which NEON has no vector mul instruction), we try to create mull more aggressively to avoid scalarization.
1 parent bdfadb1 commit 548d8fe

File tree

2 files changed

+53
-146
lines changed

2 files changed

+53
-146
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5171,40 +5171,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
51715171
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
51725172
}
51735173

5174-
static EVT getExtensionTo64Bits(const EVT &OrigVT) {
5175-
if (OrigVT.getSizeInBits() >= 64)
5176-
return OrigVT;
5177-
5178-
assert(OrigVT.isSimple() && "Expecting a simple value type");
5179-
5180-
MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
5181-
switch (OrigSimpleTy) {
5182-
default: llvm_unreachable("Unexpected Vector Type");
5183-
case MVT::v2i8:
5184-
case MVT::v2i16:
5185-
return MVT::v2i32;
5186-
case MVT::v4i8:
5187-
return MVT::v4i16;
5188-
}
5189-
}
5190-
5191-
static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
5192-
const EVT &OrigTy,
5193-
const EVT &ExtTy,
5194-
unsigned ExtOpcode) {
5195-
// The vector originally had a size of OrigTy. It was then extended to ExtTy.
5196-
// We expect the ExtTy to be 128-bits total. If the OrigTy is less than
5197-
// 64-bits we need to insert a new extension so that it will be 64-bits.
5198-
assert(ExtTy.is128BitVector() && "Unexpected extension size");
5199-
if (OrigTy.getSizeInBits() >= 64)
5200-
return N;
5201-
5202-
// Must extend size to at least 64 bits to be used as an operand for VMULL.
5203-
EVT NewVT = getExtensionTo64Bits(OrigTy);
5204-
5205-
return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
5206-
}
5207-
52085174
// Returns lane if Op extracts from a two-element vector and lane is constant
52095175
// (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
52105176
static std::optional<uint64_t>
@@ -5250,31 +5216,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
52505216
static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
52515217
EVT VT = N.getValueType();
52525218
assert(VT.is128BitVector() && "Unexpected vector MULL size");
5253-
5254-
unsigned NumElts = VT.getVectorNumElements();
5255-
unsigned OrigEltSize = VT.getScalarSizeInBits();
5256-
unsigned EltSize = OrigEltSize / 2;
5257-
MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
5258-
5259-
APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
5260-
if (DAG.MaskedValueIsZero(N, HiBits))
5261-
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
5262-
5263-
if (ISD::isExtOpcode(N.getOpcode()))
5264-
return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
5265-
N.getOperand(0).getValueType(), VT,
5266-
N.getOpcode());
5267-
5268-
assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
5269-
SDLoc dl(N);
5270-
SmallVector<SDValue, 8> Ops;
5271-
for (unsigned i = 0; i != NumElts; ++i) {
5272-
const APInt &CInt = N.getConstantOperandAPInt(i);
5273-
// Element types smaller than 32 bits are not legal, so use i32 elements.
5274-
// The values are implicitly truncated so sext vs. zext doesn't matter.
5275-
Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
5276-
}
5277-
return DAG.getBuildVector(TruncVT, dl, Ops);
5219+
EVT HalfVT = EVT::getVectorVT(
5220+
*DAG.getContext(),
5221+
VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()),
5222+
VT.getVectorElementCount());
5223+
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N);
52785224
}
52795225

52805226
static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -5450,34 +5396,27 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
54505396
if (IsN0ZExt && IsN1ZExt)
54515397
return AArch64ISD::UMULL;
54525398

5453-
// Select SMULL if we can replace zext with sext.
5454-
if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
5455-
!isExtendedBUILD_VECTOR(N0, DAG, false) &&
5456-
!isExtendedBUILD_VECTOR(N1, DAG, false)) {
5457-
SDValue ZextOperand;
5458-
if (IsN0ZExt)
5459-
ZextOperand = N0.getOperand(0);
5460-
else
5461-
ZextOperand = N1.getOperand(0);
5462-
if (DAG.SignBitIsZero(ZextOperand)) {
5463-
SDValue NewSext =
5464-
DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
5465-
if (IsN0ZExt)
5466-
N0 = NewSext;
5467-
else
5468-
N1 = NewSext;
5469-
return AArch64ISD::SMULL;
5470-
}
5471-
}
5472-
54735399
// Select UMULL if we can replace the other operand with an extend.
5474-
if (IsN0ZExt || IsN1ZExt) {
5475-
EVT VT = N0.getValueType();
5476-
APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5477-
VT.getScalarSizeInBits() / 2);
5400+
EVT VT = N0.getValueType();
5401+
APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5402+
VT.getScalarSizeInBits() / 2);
5403+
if (IsN0ZExt || IsN1ZExt)
54785404
if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
54795405
return AArch64ISD::UMULL;
5480-
}
5406+
// For v2i64 we look more aggressively at both operands having zero high bits, to avoid
5407+
// scalarization.
5408+
if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&
5409+
DAG.MaskedValueIsZero(N1, Mask))
5410+
return AArch64ISD::UMULL;
5411+
5412+
if (IsN0SExt || IsN1SExt)
5413+
if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) >
5414+
VT.getScalarSizeInBits() / 2)
5415+
return AArch64ISD::SMULL;
5416+
if (VT == MVT::v2i64 &&
5417+
DAG.ComputeNumSignBits(N0) > VT.getScalarSizeInBits() / 2 &&
5418+
DAG.ComputeNumSignBits(N1) > VT.getScalarSizeInBits() / 2)
5419+
return AArch64ISD::SMULL;
54815420

54825421
if (!IsN1SExt && !IsN1ZExt)
54835422
return 0;

llvm/test/CodeGen/AArch64/aarch64-smull.ll

Lines changed: 30 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind {
231231
define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind {
232232
; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64:
233233
; CHECK-NEON: // %bb.0:
234-
; CHECK-NEON-NEXT: ldr d0, [x1]
235-
; CHECK-NEON-NEXT: ldrh w9, [x0]
236-
; CHECK-NEON-NEXT: ldrh w10, [x0, #2]
237-
; CHECK-NEON-NEXT: sshll v0.2d, v0.2s, #0
238-
; CHECK-NEON-NEXT: fmov x11, d0
239-
; CHECK-NEON-NEXT: mov x8, v0.d[1]
240-
; CHECK-NEON-NEXT: smull x9, w9, w11
241-
; CHECK-NEON-NEXT: smull x8, w10, w8
242-
; CHECK-NEON-NEXT: fmov d0, x9
243-
; CHECK-NEON-NEXT: mov v0.d[1], x8
234+
; CHECK-NEON-NEXT: ldrh w8, [x0]
235+
; CHECK-NEON-NEXT: ldrh w9, [x0, #2]
236+
; CHECK-NEON-NEXT: ldr d1, [x1]
237+
; CHECK-NEON-NEXT: fmov d0, x8
238+
; CHECK-NEON-NEXT: mov v0.d[1], x9
239+
; CHECK-NEON-NEXT: xtn v0.2s, v0.2d
240+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
244241
; CHECK-NEON-NEXT: ret
245242
;
246243
; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64:
247244
; CHECK-SVE: // %bb.0:
248245
; CHECK-SVE-NEXT: ldrh w8, [x0]
249246
; CHECK-SVE-NEXT: ldrh w9, [x0, #2]
250-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
251-
; CHECK-SVE-NEXT: ldr d0, [x1]
252-
; CHECK-SVE-NEXT: fmov d1, x8
253-
; CHECK-SVE-NEXT: sshll v0.2d, v0.2s, #0
254-
; CHECK-SVE-NEXT: mov v1.d[1], x9
255-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
256-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
247+
; CHECK-SVE-NEXT: ldr d1, [x1]
248+
; CHECK-SVE-NEXT: fmov d0, x8
249+
; CHECK-SVE-NEXT: mov v0.d[1], x9
250+
; CHECK-SVE-NEXT: xtn v0.2s, v0.2d
251+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
257252
; CHECK-SVE-NEXT: ret
258253
;
259254
; CHECK-GI-LABEL: smull_zext_v2i32_v2i64:
@@ -2405,25 +2400,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
24052400
define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) {
24062401
; CHECK-NEON-LABEL: lsr:
24072402
; CHECK-NEON: // %bb.0:
2408-
; CHECK-NEON-NEXT: ushr v0.2d, v0.2d, #32
2409-
; CHECK-NEON-NEXT: ushr v1.2d, v1.2d, #32
2410-
; CHECK-NEON-NEXT: fmov x10, d1
2411-
; CHECK-NEON-NEXT: fmov x11, d0
2412-
; CHECK-NEON-NEXT: mov x8, v1.d[1]
2413-
; CHECK-NEON-NEXT: mov x9, v0.d[1]
2414-
; CHECK-NEON-NEXT: umull x10, w11, w10
2415-
; CHECK-NEON-NEXT: umull x8, w9, w8
2416-
; CHECK-NEON-NEXT: fmov d0, x10
2417-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2403+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2404+
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
2405+
; CHECK-NEON-NEXT: umull v0.2d, v0.2s, v1.2s
24182406
; CHECK-NEON-NEXT: ret
24192407
;
24202408
; CHECK-SVE-LABEL: lsr:
24212409
; CHECK-SVE: // %bb.0:
2422-
; CHECK-SVE-NEXT: ushr v0.2d, v0.2d, #32
2423-
; CHECK-SVE-NEXT: ushr v1.2d, v1.2d, #32
2424-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2425-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2426-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2410+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2411+
; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
2412+
; CHECK-SVE-NEXT: umull v0.2d, v0.2s, v1.2s
24272413
; CHECK-SVE-NEXT: ret
24282414
;
24292415
; CHECK-GI-LABEL: lsr:
@@ -2482,25 +2468,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) {
24822468
define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
24832469
; CHECK-NEON-LABEL: asr:
24842470
; CHECK-NEON: // %bb.0:
2485-
; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
2486-
; CHECK-NEON-NEXT: sshr v1.2d, v1.2d, #32
2487-
; CHECK-NEON-NEXT: fmov x10, d1
2488-
; CHECK-NEON-NEXT: fmov x11, d0
2489-
; CHECK-NEON-NEXT: mov x8, v1.d[1]
2490-
; CHECK-NEON-NEXT: mov x9, v0.d[1]
2491-
; CHECK-NEON-NEXT: smull x10, w11, w10
2492-
; CHECK-NEON-NEXT: smull x8, w9, w8
2493-
; CHECK-NEON-NEXT: fmov d0, x10
2494-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2471+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2472+
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
2473+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
24952474
; CHECK-NEON-NEXT: ret
24962475
;
24972476
; CHECK-SVE-LABEL: asr:
24982477
; CHECK-SVE: // %bb.0:
2499-
; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
2500-
; CHECK-SVE-NEXT: sshr v1.2d, v1.2d, #32
2501-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2502-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2503-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2478+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2479+
; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
2480+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
25042481
; CHECK-SVE-NEXT: ret
25052482
;
25062483
; CHECK-GI-LABEL: asr:
@@ -2525,25 +2502,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
25252502
define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) {
25262503
; CHECK-NEON-LABEL: asr_const:
25272504
; CHECK-NEON: // %bb.0:
2528-
; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
2529-
; CHECK-NEON-NEXT: fmov x9, d0
2530-
; CHECK-NEON-NEXT: mov x8, v0.d[1]
2531-
; CHECK-NEON-NEXT: lsl x10, x9, #5
2532-
; CHECK-NEON-NEXT: lsl x11, x8, #5
2533-
; CHECK-NEON-NEXT: sub x9, x10, x9
2534-
; CHECK-NEON-NEXT: fmov d0, x9
2535-
; CHECK-NEON-NEXT: sub x8, x11, x8
2536-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2505+
; CHECK-NEON-NEXT: movi v1.2s, #31
2506+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2507+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
25372508
; CHECK-NEON-NEXT: ret
25382509
;
25392510
; CHECK-SVE-LABEL: asr_const:
25402511
; CHECK-SVE: // %bb.0:
2541-
; CHECK-SVE-NEXT: mov w8, #31 // =0x1f
2542-
; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
2543-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2544-
; CHECK-SVE-NEXT: dup v1.2d, x8
2545-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2546-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2512+
; CHECK-SVE-NEXT: movi v1.2s, #31
2513+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2514+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
25472515
; CHECK-SVE-NEXT: ret
25482516
;
25492517
; CHECK-GI-LABEL: asr_const:

0 commit comments

Comments
 (0)