Skip to content

Commit 2630882

Browse files
committed
[AArch64] Improve mull generation
This attempts to clean up and improve where we generate smull using known-bits. For v2i64 types (where no mul is present), we try to create mull more aggressively to avoid scalarization.
1 parent debfd7b commit 2630882

File tree

2 files changed

+53
-146
lines changed

2 files changed

+53
-146
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 23 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5173,40 +5173,6 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
51735173
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
51745174
}
51755175

5176-
static EVT getExtensionTo64Bits(const EVT &OrigVT) {
5177-
if (OrigVT.getSizeInBits() >= 64)
5178-
return OrigVT;
5179-
5180-
assert(OrigVT.isSimple() && "Expecting a simple value type");
5181-
5182-
MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
5183-
switch (OrigSimpleTy) {
5184-
default: llvm_unreachable("Unexpected Vector Type");
5185-
case MVT::v2i8:
5186-
case MVT::v2i16:
5187-
return MVT::v2i32;
5188-
case MVT::v4i8:
5189-
return MVT::v4i16;
5190-
}
5191-
}
5192-
5193-
static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
5194-
const EVT &OrigTy,
5195-
const EVT &ExtTy,
5196-
unsigned ExtOpcode) {
5197-
// The vector originally had a size of OrigTy. It was then extended to ExtTy.
5198-
// We expect the ExtTy to be 128-bits total. If the OrigTy is less than
5199-
// 64-bits we need to insert a new extension so that it will be 64-bits.
5200-
assert(ExtTy.is128BitVector() && "Unexpected extension size");
5201-
if (OrigTy.getSizeInBits() >= 64)
5202-
return N;
5203-
5204-
// Must extend size to at least 64 bits to be used as an operand for VMULL.
5205-
EVT NewVT = getExtensionTo64Bits(OrigTy);
5206-
5207-
return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
5208-
}
5209-
52105176
// Returns lane if Op extracts from a two-element vector and lane is constant
52115177
// (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
52125178
static std::optional<uint64_t>
@@ -5252,31 +5218,11 @@ static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
52525218
static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
52535219
EVT VT = N.getValueType();
52545220
assert(VT.is128BitVector() && "Unexpected vector MULL size");
5255-
5256-
unsigned NumElts = VT.getVectorNumElements();
5257-
unsigned OrigEltSize = VT.getScalarSizeInBits();
5258-
unsigned EltSize = OrigEltSize / 2;
5259-
MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
5260-
5261-
APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
5262-
if (DAG.MaskedValueIsZero(N, HiBits))
5263-
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
5264-
5265-
if (ISD::isExtOpcode(N.getOpcode()))
5266-
return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
5267-
N.getOperand(0).getValueType(), VT,
5268-
N.getOpcode());
5269-
5270-
assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
5271-
SDLoc dl(N);
5272-
SmallVector<SDValue, 8> Ops;
5273-
for (unsigned i = 0; i != NumElts; ++i) {
5274-
const APInt &CInt = N.getConstantOperandAPInt(i);
5275-
// Element types smaller than 32 bits are not legal, so use i32 elements.
5276-
// The values are implicitly truncated so sext vs. zext doesn't matter.
5277-
Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
5278-
}
5279-
return DAG.getBuildVector(TruncVT, dl, Ops);
5221+
EVT HalfVT = EVT::getVectorVT(
5222+
*DAG.getContext(),
5223+
VT.getScalarType().getHalfSizedIntegerVT(*DAG.getContext()),
5224+
VT.getVectorElementCount());
5225+
return DAG.getNode(ISD::TRUNCATE, SDLoc(N), HalfVT, N);
52805226
}
52815227

52825228
static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -5452,34 +5398,27 @@ static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
54525398
if (IsN0ZExt && IsN1ZExt)
54535399
return AArch64ISD::UMULL;
54545400

5455-
// Select SMULL if we can replace zext with sext.
5456-
if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
5457-
!isExtendedBUILD_VECTOR(N0, DAG, false) &&
5458-
!isExtendedBUILD_VECTOR(N1, DAG, false)) {
5459-
SDValue ZextOperand;
5460-
if (IsN0ZExt)
5461-
ZextOperand = N0.getOperand(0);
5462-
else
5463-
ZextOperand = N1.getOperand(0);
5464-
if (DAG.SignBitIsZero(ZextOperand)) {
5465-
SDValue NewSext =
5466-
DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
5467-
if (IsN0ZExt)
5468-
N0 = NewSext;
5469-
else
5470-
N1 = NewSext;
5471-
return AArch64ISD::SMULL;
5472-
}
5473-
}
5474-
54755401
// Select UMULL if we can replace the other operand with an extend.
5476-
if (IsN0ZExt || IsN1ZExt) {
5477-
EVT VT = N0.getValueType();
5478-
APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5479-
VT.getScalarSizeInBits() / 2);
5402+
EVT VT = N0.getValueType();
5403+
APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5404+
VT.getScalarSizeInBits() / 2);
5405+
if (IsN0ZExt || IsN1ZExt)
54805406
if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
54815407
return AArch64ISD::UMULL;
5482-
}
5408+
// For v2i64 we look more aggresively at both operands being zero, to avoid
5409+
// scalarization.
5410+
if (VT == MVT::v2i64 && DAG.MaskedValueIsZero(N0, Mask) &&
5411+
DAG.MaskedValueIsZero(N1, Mask))
5412+
return AArch64ISD::UMULL;
5413+
5414+
if (IsN0SExt || IsN1SExt)
5415+
if (DAG.ComputeNumSignBits(IsN0SExt ? N1 : N0) >
5416+
VT.getScalarSizeInBits() / 2)
5417+
return AArch64ISD::SMULL;
5418+
if (VT == MVT::v2i64 &&
5419+
DAG.ComputeNumSignBits(N0) > VT.getScalarSizeInBits() / 2 &&
5420+
DAG.ComputeNumSignBits(N1) > VT.getScalarSizeInBits() / 2)
5421+
return AArch64ISD::SMULL;
54835422

54845423
if (!IsN1SExt && !IsN1ZExt)
54855424
return 0;

llvm/test/CodeGen/AArch64/aarch64-smull.ll

Lines changed: 30 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -231,29 +231,24 @@ define <4 x i32> @smull_zext_v4i16_v4i32(ptr %A, ptr %B) nounwind {
231231
define <2 x i64> @smull_zext_v2i32_v2i64(ptr %A, ptr %B) nounwind {
232232
; CHECK-NEON-LABEL: smull_zext_v2i32_v2i64:
233233
; CHECK-NEON: // %bb.0:
234-
; CHECK-NEON-NEXT: ldr d0, [x1]
235-
; CHECK-NEON-NEXT: ldrh w9, [x0]
236-
; CHECK-NEON-NEXT: ldrh w10, [x0, #2]
237-
; CHECK-NEON-NEXT: sshll v0.2d, v0.2s, #0
238-
; CHECK-NEON-NEXT: fmov x11, d0
239-
; CHECK-NEON-NEXT: mov x8, v0.d[1]
240-
; CHECK-NEON-NEXT: smull x9, w9, w11
241-
; CHECK-NEON-NEXT: smull x8, w10, w8
242-
; CHECK-NEON-NEXT: fmov d0, x9
243-
; CHECK-NEON-NEXT: mov v0.d[1], x8
234+
; CHECK-NEON-NEXT: ldrh w8, [x0]
235+
; CHECK-NEON-NEXT: ldrh w9, [x0, #2]
236+
; CHECK-NEON-NEXT: ldr d1, [x1]
237+
; CHECK-NEON-NEXT: fmov d0, x8
238+
; CHECK-NEON-NEXT: mov v0.d[1], x9
239+
; CHECK-NEON-NEXT: xtn v0.2s, v0.2d
240+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
244241
; CHECK-NEON-NEXT: ret
245242
;
246243
; CHECK-SVE-LABEL: smull_zext_v2i32_v2i64:
247244
; CHECK-SVE: // %bb.0:
248245
; CHECK-SVE-NEXT: ldrh w8, [x0]
249246
; CHECK-SVE-NEXT: ldrh w9, [x0, #2]
250-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
251-
; CHECK-SVE-NEXT: ldr d0, [x1]
252-
; CHECK-SVE-NEXT: fmov d1, x8
253-
; CHECK-SVE-NEXT: sshll v0.2d, v0.2s, #0
254-
; CHECK-SVE-NEXT: mov v1.d[1], x9
255-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
256-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
247+
; CHECK-SVE-NEXT: ldr d1, [x1]
248+
; CHECK-SVE-NEXT: fmov d0, x8
249+
; CHECK-SVE-NEXT: mov v0.d[1], x9
250+
; CHECK-SVE-NEXT: xtn v0.2s, v0.2d
251+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
257252
; CHECK-SVE-NEXT: ret
258253
;
259254
; CHECK-GI-LABEL: smull_zext_v2i32_v2i64:
@@ -2405,25 +2400,16 @@ define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
24052400
define <2 x i64> @lsr(<2 x i64> %a, <2 x i64> %b) {
24062401
; CHECK-NEON-LABEL: lsr:
24072402
; CHECK-NEON: // %bb.0:
2408-
; CHECK-NEON-NEXT: ushr v0.2d, v0.2d, #32
2409-
; CHECK-NEON-NEXT: ushr v1.2d, v1.2d, #32
2410-
; CHECK-NEON-NEXT: fmov x10, d1
2411-
; CHECK-NEON-NEXT: fmov x11, d0
2412-
; CHECK-NEON-NEXT: mov x8, v1.d[1]
2413-
; CHECK-NEON-NEXT: mov x9, v0.d[1]
2414-
; CHECK-NEON-NEXT: umull x10, w11, w10
2415-
; CHECK-NEON-NEXT: umull x8, w9, w8
2416-
; CHECK-NEON-NEXT: fmov d0, x10
2417-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2403+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2404+
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
2405+
; CHECK-NEON-NEXT: umull v0.2d, v0.2s, v1.2s
24182406
; CHECK-NEON-NEXT: ret
24192407
;
24202408
; CHECK-SVE-LABEL: lsr:
24212409
; CHECK-SVE: // %bb.0:
2422-
; CHECK-SVE-NEXT: ushr v0.2d, v0.2d, #32
2423-
; CHECK-SVE-NEXT: ushr v1.2d, v1.2d, #32
2424-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2425-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2426-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2410+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2411+
; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
2412+
; CHECK-SVE-NEXT: umull v0.2d, v0.2s, v1.2s
24272413
; CHECK-SVE-NEXT: ret
24282414
;
24292415
; CHECK-GI-LABEL: lsr:
@@ -2482,25 +2468,16 @@ define <2 x i64> @lsr_const(<2 x i64> %a, <2 x i64> %b) {
24822468
define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
24832469
; CHECK-NEON-LABEL: asr:
24842470
; CHECK-NEON: // %bb.0:
2485-
; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
2486-
; CHECK-NEON-NEXT: sshr v1.2d, v1.2d, #32
2487-
; CHECK-NEON-NEXT: fmov x10, d1
2488-
; CHECK-NEON-NEXT: fmov x11, d0
2489-
; CHECK-NEON-NEXT: mov x8, v1.d[1]
2490-
; CHECK-NEON-NEXT: mov x9, v0.d[1]
2491-
; CHECK-NEON-NEXT: smull x10, w11, w10
2492-
; CHECK-NEON-NEXT: smull x8, w9, w8
2493-
; CHECK-NEON-NEXT: fmov d0, x10
2494-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2471+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2472+
; CHECK-NEON-NEXT: shrn v1.2s, v1.2d, #32
2473+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
24952474
; CHECK-NEON-NEXT: ret
24962475
;
24972476
; CHECK-SVE-LABEL: asr:
24982477
; CHECK-SVE: // %bb.0:
2499-
; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
2500-
; CHECK-SVE-NEXT: sshr v1.2d, v1.2d, #32
2501-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2502-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2503-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2478+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2479+
; CHECK-SVE-NEXT: shrn v1.2s, v1.2d, #32
2480+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
25042481
; CHECK-SVE-NEXT: ret
25052482
;
25062483
; CHECK-GI-LABEL: asr:
@@ -2525,25 +2502,16 @@ define <2 x i64> @asr(<2 x i64> %a, <2 x i64> %b) {
25252502
define <2 x i64> @asr_const(<2 x i64> %a, <2 x i64> %b) {
25262503
; CHECK-NEON-LABEL: asr_const:
25272504
; CHECK-NEON: // %bb.0:
2528-
; CHECK-NEON-NEXT: sshr v0.2d, v0.2d, #32
2529-
; CHECK-NEON-NEXT: fmov x9, d0
2530-
; CHECK-NEON-NEXT: mov x8, v0.d[1]
2531-
; CHECK-NEON-NEXT: lsl x10, x9, #5
2532-
; CHECK-NEON-NEXT: lsl x11, x8, #5
2533-
; CHECK-NEON-NEXT: sub x9, x10, x9
2534-
; CHECK-NEON-NEXT: fmov d0, x9
2535-
; CHECK-NEON-NEXT: sub x8, x11, x8
2536-
; CHECK-NEON-NEXT: mov v0.d[1], x8
2505+
; CHECK-NEON-NEXT: movi v1.2s, #31
2506+
; CHECK-NEON-NEXT: shrn v0.2s, v0.2d, #32
2507+
; CHECK-NEON-NEXT: smull v0.2d, v0.2s, v1.2s
25372508
; CHECK-NEON-NEXT: ret
25382509
;
25392510
; CHECK-SVE-LABEL: asr_const:
25402511
; CHECK-SVE: // %bb.0:
2541-
; CHECK-SVE-NEXT: mov w8, #31 // =0x1f
2542-
; CHECK-SVE-NEXT: sshr v0.2d, v0.2d, #32
2543-
; CHECK-SVE-NEXT: ptrue p0.d, vl2
2544-
; CHECK-SVE-NEXT: dup v1.2d, x8
2545-
; CHECK-SVE-NEXT: mul z0.d, p0/m, z0.d, z1.d
2546-
; CHECK-SVE-NEXT: // kill: def $q0 killed $q0 killed $z0
2512+
; CHECK-SVE-NEXT: movi v1.2s, #31
2513+
; CHECK-SVE-NEXT: shrn v0.2s, v0.2d, #32
2514+
; CHECK-SVE-NEXT: smull v0.2d, v0.2s, v1.2s
25472515
; CHECK-SVE-NEXT: ret
25482516
;
25492517
; CHECK-GI-LABEL: asr_const:

0 commit comments

Comments
 (0)