diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index d78cf00a5a2fc..814a4bd1df714 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44969,11 +44969,19 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
                              TLO, Depth + 1))
       return true;
 
-    // X * 0 + Y --> Y
-    // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
-    // zeroes.
-    if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
-      return TLO.CombineTo(Op, Op2);
+    KnownBits KnownMul;
+    KnownOp0 = KnownOp0.trunc(52);
+    KnownOp1 = KnownOp1.trunc(52);
+    KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
+                                        : KnownBits::mulhu(KnownOp0, KnownOp1);
+    KnownMul = KnownMul.zext(64);
+
+    // lo/hi(X * Y) + Z --> C + Z
+    if (KnownMul.isConstant()) {
+      SDLoc DL(Op);
+      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
+      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
+    }
 
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.
     break;
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index fd295ea31c55c..9afc1119267ec 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -183,3 +183,110 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> , <2 x i64> %x1)
   ret <2 x i64> %1
 }
+
+; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+  ret <2 x i64> %1
+}
+
+; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+  ret <2 x i64> %1
+}
+
+; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits are 1 << 50
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 40)
+  %and2 = lshr <2 x i64> %x1, splat (i64 40)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 30)
+  %and2 = lshr <2 x i64> %x1, splat (i64 43)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
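
The arithmetic the new fold relies on can be sanity-checked outside of LLVM. Below is a minimal standalone sketch, written in plain C++ rather than against the KnownBits API, of what VPMADD52L/VPMADD52H contribute per lane: the low (respectively high) 52 bits of the 104-bit product of the low 52 bits of the two multiplicands. The helper names are invented for illustration, and the sketch assumes the GCC/Clang unsigned __int128 extension.

// madd52_sketch.cpp -- check the constants the new tests expect.
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr uint64_t Mask52 = (1ULL << 52) - 1;

// Low 52 bits of (X * Y), with X and Y first truncated to 52 bits
// (the term VPMADD52L adds to its accumulator).
static uint64_t Madd52LoTerm(uint64_t X, uint64_t Y) {
  unsigned __int128 P = (unsigned __int128)(X & Mask52) * (Y & Mask52);
  return (uint64_t)(P & Mask52);
}

// High 52 bits of the same 104-bit product (the VPMADD52H term).
static uint64_t Madd52HiTerm(uint64_t X, uint64_t Y) {
  unsigned __int128 P = (unsigned __int128)(X & Mask52) * (Y & Mask52);
  return (uint64_t)(P >> 52);
}

int main() {
  // test_vpmadd52l_mul_lo52_zero: (1 << 51) * (1 << 1) = 1 << 52, so the
  // low 52 bits are zero and the intrinsic folds to its accumulator.
  assert(Madd52LoTerm(1ULL << 51, 2) == 0);

  // test_vpmadd52h_mul_hi52_zero: (1 << 25) * (1 << 26) = 1 << 51, so the
  // high 52 bits are zero.
  assert(Madd52HiTerm(1ULL << 25, 1ULL << 26) == 0);

  // test_vpmadd52l_mul_lo52_const: 123 * 456 = 56088 fits in 52 bits, so
  // the whole intrinsic folds to a VPADDQ of that constant.
  assert(Madd52LoTerm(123, 456) == 56088);

  // test_vpmadd52h_mul_hi52_const: (1 << 51) * (1 << 51) = 1 << 102, whose
  // high 52 bits are 1 << 50.
  assert(Madd52HiTerm(1ULL << 51, 1ULL << 51) == 1ULL << 50);

  puts("all expected constants check out");
  return 0;
}

The negative tests follow the same model: in test_vpmadd52l_mul_lo52_mask_negative the per-lane product is either 0 or 1 << 51 depending on the masked bits, so KnownMul is not a constant and the vpmadd52luq must be kept.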