Skip to content

Commit fba55c8

Browse files
authored
[X86] Fold X * 1 + Z --> X + Z for VPMADD52L (#158516)
This patch implements the fold `lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)`.
1 parent a254f65 commit fba55c8

File tree

2 files changed

+80
-1
lines changed

2 files changed

+80
-1
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60201,8 +60201,30 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
6020160201
static SDValue combineVPMADD52LH(SDNode *N, SelectionDAG &DAG,
6020260202
TargetLowering::DAGCombinerInfo &DCI) {
6020360203
MVT VT = N->getSimpleValueType(0);
60204-
unsigned NumEltBits = VT.getScalarSizeInBits();
60204+
60205+
bool AddLow = N->getOpcode() == X86ISD::VPMADD52L;
60206+
SDValue Op0 = N->getOperand(0);
60207+
SDValue Op1 = N->getOperand(1);
60208+
SDValue Op2 = N->getOperand(2);
60209+
SDLoc DL(N);
60210+
60211+
APInt C0, C1;
60212+
bool HasC0 = X86::isConstantSplat(Op0, C0),
60213+
HasC1 = X86::isConstantSplat(Op1, C1);
60214+
60215+
// lo/hi(C * X) + Z --> lo/hi(X * C) + Z
60216+
if (HasC0 && !HasC1)
60217+
return DAG.getNode(N->getOpcode(), DL, VT, Op1, Op0, Op2);
60218+
60219+
// lo(X * 1) + Z --> lo(X) + Z iff X == lo(X)
60220+
if (AddLow && HasC1 && C1.trunc(52).isOne()) {
60221+
KnownBits KnownOp0 = DAG.computeKnownBits(Op0);
60222+
if (KnownOp0.countMinLeadingZeros() >= 12)
60223+
return DAG.getNode(ISD::ADD, DL, VT, Op0, Op2);
60224+
}
60225+
6020560226
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
60227+
unsigned NumEltBits = VT.getScalarSizeInBits();
6020660228
if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
6020760229
DCI))
6020860230
return SDValue(N, 0);

llvm/test/CodeGen/X86/combine-vpmadd52.ll

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,3 +398,60 @@ define <2 x i64> @test3_knownbits_vpmadd52h_negative(<2 x i64> %x0, <2 x i64> %x
398398
%ret = and <2 x i64> %madd, splat (i64 1)
399399
ret <2 x i64> %ret
400400
}
401+
402+
define <2 x i64> @test_vpmadd52l_mul_one(<2 x i64> %x0, <2 x i32> %x1) {
403+
; CHECK-LABEL: test_vpmadd52l_mul_one:
404+
; CHECK: # %bb.0:
405+
; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
406+
; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0
407+
; CHECK-NEXT: retq
408+
%ext = zext <2 x i32> %x1 to <2 x i64>
409+
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %ext)
410+
ret <2 x i64> %ifma
411+
}
412+
413+
define <2 x i64> @test_vpmadd52l_mul_one_commuted(<2 x i64> %x0, <2 x i32> %x1) {
414+
; CHECK-LABEL: test_vpmadd52l_mul_one_commuted:
415+
; CHECK: # %bb.0:
416+
; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
417+
; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0
418+
; CHECK-NEXT: retq
419+
%ext = zext <2 x i32> %x1 to <2 x i64>
420+
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %ext, <2 x i64> splat(i64 1))
421+
ret <2 x i64> %ifma
422+
}
423+
424+
define <2 x i64> @test_vpmadd52l_mul_one_no_mask(<2 x i64> %x0, <2 x i64> %x1) {
425+
; AVX512-LABEL: test_vpmadd52l_mul_one_no_mask:
426+
; AVX512: # %bb.0:
427+
; AVX512-NEXT: vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm0
428+
; AVX512-NEXT: retq
429+
;
430+
; AVX-LABEL: test_vpmadd52l_mul_one_no_mask:
431+
; AVX: # %bb.0:
432+
; AVX-NEXT: {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
433+
; AVX-NEXT: retq
434+
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1)
435+
ret <2 x i64> %ifma
436+
}
437+
438+
; Mul by (1 << 52) + 1
439+
define <2 x i64> @test_vpmadd52l_mul_one_in_52bits(<2 x i64> %x0, <2 x i32> %x1) {
440+
; CHECK-LABEL: test_vpmadd52l_mul_one_in_52bits:
441+
; CHECK: # %bb.0:
442+
; CHECK-NEXT: vpmovzxdq {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero
443+
; CHECK-NEXT: vpaddq %xmm0, %xmm1, %xmm0
444+
; CHECK-NEXT: retq
445+
%ext = zext <2 x i32> %x1 to <2 x i64>
446+
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 4503599627370497), <2 x i64> %ext)
447+
ret <2 x i64> %ifma
448+
}
449+
450+
; lo(x1) * 1 = lo(x1), the high 52 bits are zeroes still.
451+
define <2 x i64> @test_vpmadd52h_mul_one(<2 x i64> %x0, <2 x i64> %x1) {
452+
; CHECK-LABEL: test_vpmadd52h_mul_one:
453+
; CHECK: # %bb.0:
454+
; CHECK-NEXT: retq
455+
%ifma = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat(i64 1), <2 x i64> %x1)
456+
ret <2 x i64> %ifma
457+
}

0 commit comments

Comments
 (0)