Skip to content

Commit c241eb3

Browse files
authored
[X86] Fold X * Y + Z --> C + Z for vpmadd52l/vpmadd52h (#156293)
Address the TODO and implement a constant fold for the intermediate multiplication result of vpmadd52l/vpmadd52h: when the low (vpmadd52l) or high (vpmadd52h) 52 bits of X * Y are a known constant C, fold the node to C + Z.
1 parent 37127f7 commit c241eb3

File tree

2 files changed

+120
-5
lines changed

2 files changed

+120
-5
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44969,11 +44969,19 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
                              TLO, Depth + 1))
       return true;
 
-    // X * 0 + Y --> Y
-    // TODO: Handle cases where lower/higher 52 of bits of Op0 * Op1 are known
-    // zeroes.
-    if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
-      return TLO.CombineTo(Op, Op2);
+    KnownBits KnownMul;
+    KnownOp0 = KnownOp0.trunc(52);
+    KnownOp1 = KnownOp1.trunc(52);
+    KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
+                                        : KnownBits::mulhu(KnownOp0, KnownOp1);
+    KnownMul = KnownMul.zext(64);
+
+    // lo/hi(X * Y) + Z --> C + Z
+    if (KnownMul.isConstant()) {
+      SDLoc DL(Op);
+      SDValue C = TLO.DAG.getConstant(KnownMul.getConstant(), DL, VT);
+      return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
+    }
 
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.
     break;

llvm/test/CodeGen/X86/combine-vpmadd52.ll

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,110 @@ define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 123>, <2 x i64> %x1)
   ret <2 x i64> %1
 }
+
+; (1 << 51) * (1 << 1) -> 1 << 52 -> low 52 bits are zeroes
+define <2 x i64> @test_vpmadd52l_mul_lo52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2))
+  ret <2 x i64> %1
+}
+
+; (1 << 25) * (1 << 26) = 1 << 51 -> high 52 bits are zeroes
+define <2 x i64> @test_vpmadd52h_mul_hi52_zero(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 33554432), <2 x i64> splat (i64 67108864))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 123), <2 x i64> splat (i64 456))
+  ret <2 x i64> %1
+}
+
+; (1 << 51) * (1 << 51) -> 1 << 102 -> the high 52 bits is 1 << 50
+define <2 x i64> @test_vpmadd52h_mul_hi52_const(<2 x i64> %x0) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_const:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    retq
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 2251799813685248), <2 x i64> splat (i64 2251799813685248))
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52l_mul_lo52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 1073741824) ; 1LL << 30
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_mask(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; CHECK-LABEL: test_vpmadd52h_mul_hi52_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 40)
+  %and2 = lshr <2 x i64> %x1, splat (i64 40)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_lo52_mask_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm0, %xmm2
+; AVX512-NEXT:    vpandq {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to2}, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_lo52_mask_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2
+; AVX-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = and <2 x i64> %x0, splat (i64 2097152) ; 1LL << 21
+  %and2 = and <2 x i64> %x1, splat (i64 1073741824) ; 1LL << 30
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_hi52_negative(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
+; AVX512-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX512-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX512-NEXT:    vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52h_mul_hi52_negative:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpsrlq $30, %xmm0, %xmm2
+; AVX-NEXT:    vpsrlq $43, %xmm1, %xmm1
+; AVX-NEXT:    {vex} vpmadd52huq %xmm1, %xmm2, %xmm0
+; AVX-NEXT:    retq
+  %and1 = lshr <2 x i64> %x0, splat (i64 30)
+  %and2 = lshr <2 x i64> %x1, splat (i64 43)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and1, <2 x i64> %and2)
+  ret <2 x i64> %1
+}

0 commit comments

Comments
 (0)