-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h #156293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-backend-x86 Author: XChy (XChy) Changes: Address the TODO and implement constant folding for the intermediate multiplication result of vpmadd52l/vpmadd52h. Full diff: https://github.com/llvm/llvm-project/pull/156293.diff 2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index d78cf00a5a2fc..840c2730625c0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44954,26 +44954,39 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
}
case X86ISD::VPMADD52L:
case X86ISD::VPMADD52H: {
- KnownBits KnownOp0, KnownOp1;
+ KnownBits Known52BitsOfOp0, Known52BitsOfOp1;
SDValue Op0 = Op.getOperand(0);
SDValue Op1 = Op.getOperand(1);
SDValue Op2 = Op.getOperand(2);
// Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
// operand 2).
APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
- if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
- TLO, Depth + 1))
+ if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts,
+ Known52BitsOfOp0, TLO, Depth + 1))
return true;
- if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
- TLO, Depth + 1))
+ if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts,
+ Known52BitsOfOp1, TLO, Depth + 1))
return true;
- // X * 0 + Y --> Y
- // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
- // zeroes.
- if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
- return TLO.CombineTo(Op, Op2);
+ KnownBits KnownMul;
+ Known52BitsOfOp0 = Known52BitsOfOp0.trunc(52);
+ Known52BitsOfOp1 = Known52BitsOfOp1.trunc(52);
+ if (Opc == X86ISD::VPMADD52L) {
+ KnownMul =
+ KnownBits::mul(Known52BitsOfOp0.zext(104), Known52BitsOfOp1.zext(104))
+ .trunc(52);
+ } else {
+ KnownMul = KnownBits::mulhu(Known52BitsOfOp0, Known52BitsOfOp1);
+ }
+ KnownMul = KnownMul.zext(64);
+
+ // C1 * C2 + Z --> C3 + Z
+ if (KnownMul.isConstant()) {
+ SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), SDLoc(Op0), VT);
+ return TLO.CombineTo(Op,
+ TLO.DAG.getNode(ISD::ADD, SDLoc(Op), VT, C, Op2));
+ }
// TODO: Compute the known bits for VPMADD52L/VPMADD52H.
break;
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index fd295ea31c55c..1e075bfe12a31 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -183,3 +183,110 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
%1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 123>, <2 x i64> %x1)
ret <2 x i64> %1
}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+ ; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+ ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX: # %bb.0:
+; AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+  ; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits are 1 << 50
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX: # %bb.0:
+; AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT: retq
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+ %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: retq
+ %and1 = lshr <2 x i64> %x0, splat (i64 40)
+ %and2 = lshr <2 x i64> %x1, splat (i64 40)
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT: vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT: vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX: # %bb.0:
+; AVX-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT: {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT: retq
+ %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+ %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512: # %bb.0:
+; AVX512-NEXT: vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT: vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT: vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT: retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX: # %bb.0:
+; AVX-NEXT: vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT: vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT: {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT: retq
+ %and1 = lshr <2 x i64> %x0, splat (i64 30)
+ %and2 = lshr <2 x i64> %x1, splat (i64 43)
+ %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+ ret <2 x i64> %1
+}
|
| TLO.DAG.getNode(ISD::ADD, SDLoc(Op), VT, C, Op2)); | ||
| } | ||
|
|
||
| // TODO: Compute the known bits for VPMADD52L/VPMADD52H. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Easy to resolve now.
RKSimon
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one minor - cheers
| } | ||
|
|
||
| define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) { | ||
| ; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(style) move these manual comments above the define to avoid the update script potentially mangling it
56cc479 to
880d0ca
Compare
Address TODO and implement constant fold for intermediate multiplication result of vpmadd52l/vpmadd52h.