Skip to content
50 changes: 50 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57966,6 +57966,51 @@ static SDValue pushAddIntoCmovOfConsts(SDNode *N, const SDLoc &DL,
Cmov.getOperand(3));
}

static SDValue matchIntegerMultiplyAdd(SDNode *N, SelectionDAG &DAG,
SDValue Op0, SDValue Op1,
const SDLoc &DL, EVT VT,
const X86Subtarget &Subtarget) {
using namespace SDPatternMatch;
if (!VT.isVector() || VT.getScalarType() != MVT::i64 ||
!Subtarget.hasAVX512() ||
(!Subtarget.hasAVXIFMA() && !Subtarget.hasIFMA()) ||
!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(X86ISD::VPMADD52L,
VT) ||
Op0.getValueType() != VT || Op1.getValueType() != VT)
return SDValue();

SDValue X, Y, Acc;
if (!sd_match(N, m_Add(m_Mul(m_Value(X), m_Value(Y)), m_Value(Acc))))
return SDValue();

auto CheckMulOperand = [&DAG, &VT](const SDValue &M, SDValue &Xval,
SDValue &Yval) -> bool {
if (M.getOpcode() != ISD::MUL)
return false;
const SDValue A = M.getOperand(0);
const SDValue B = M.getOperand(1);
const APInt Top12Set = APInt::getHighBitsSet(64, 12);
if (A.getValueType() != VT || B.getValueType() != VT ||
!DAG.MaskedValueIsZero(A, Top12Set) ||
!DAG.MaskedValueIsZero(B, Top12Set) ||
!DAG.MaskedValueIsZero(M, Top12Set))
return false;
Xval = A;
Yval = B;
return true;
};

if (CheckMulOperand(Op0, X, Y)) {
Acc = Op1;
} else if (CheckMulOperand(Op1, X, Y)) {
Acc = Op0;
} else {
return SDValue();
}

return DAG.getNode(X86ISD::VPMADD52L, DL, VT, Acc, X, Y);
}

static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
Expand Down Expand Up @@ -58069,6 +58114,11 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
Op0.getOperand(0), Op0.getOperand(2));
}

if (SDValue node =
matchIntegerMultiplyAdd(N, DAG, Op0, Op1, DL, VT, Subtarget)) {
return node;
}

return combineAddOrSubToADCOrSBB(N, DL, DAG);
}

Expand Down
111 changes: 111 additions & 0 deletions llvm/test/CodeGen/X86/ifma-combine-vpmadd52.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -O1 -mtriple=x86_64-unknown-unknown -mattr=+avx512dq | FileCheck %s --check-prefixes=X64

; 67108863 == (1 << 26) - 1

define dso_local <8 x i64> @test_512_combine_evex(<8 x i64> noundef %0, <8 x i64> noundef %1, <8 x i64> noundef %2) local_unnamed_addr #0 {
; X64-LABEL: test_512_combine_evex:
; X64: # %bb.0:
; X64-NEXT: vpbroadcastq {{.*#+}} zmm3 = [67108863,67108863,67108863,67108863,67108863,67108863,67108863,67108863]
; X64-NEXT: vpandq %zmm3, %zmm0, %zmm0
; X64-NEXT: vpandq %zmm3, %zmm1, %zmm1
; X64-NEXT: vpandq %zmm3, %zmm2, %zmm2
; X64-NOT: vpmul
; X64-NOT: vpadd
; X64-NEXT: vpmadd52luq %zmm1, %zmm2, %zmm0
; X64-NEXT: retq
%4 = and <8 x i64> %0, splat (i64 67108863)
%5 = and <8 x i64> %1, splat (i64 67108863)
%6 = and <8 x i64> %2, splat (i64 67108863)
%7 = mul nuw nsw <8 x i64> %5, %4
%8 = add nuw nsw <8 x i64> %7, %6
ret <8 x i64> %8
}

define dso_local <8 x i64> @fff(<8 x i64> noundef %0, <8 x i64> noundef %1, <8 x i64> noundef %2) local_unnamed_addr #0 {
%4 = and <8 x i64> %0, splat (i64 67108863)
%5 = and <8 x i64> %1, splat (i64 67108863)
%6 = and <8 x i64> %2, splat (i64 67108863)
%7 = mul nuw nsw <8 x i64> %5, %4
%8 = mul nuw nsw <8 x i64> %7, %6
%9 = add nuw nsw <8 x i64> %8, %7
ret <8 x i64> %9
}

define dso_local noundef <8 x i64> @test_512_no_combine_evex(<8 x i64> noundef %0, <8 x i64> noundef %1, <8 x i64> noundef %2) local_unnamed_addr #0 {
; X64-LABEL: test_512_no_combine_evex:
; X64: # %bb.0:
; X64-NOT: vpmadd52
; X64-NEXT: vpmullq %zmm0, %zmm1, %zmm0
; X64-NEXT: vpaddq %zmm2, %zmm0, %zmm0
; X64-NEXT: retq
%4 = mul <8 x i64> %1, %0
%5 = add <8 x i64> %4, %2
ret <8 x i64> %5
}

define dso_local <4 x i64> @test_256_combine_evex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #1 {
; X64-LABEL: test_256_combine_evex:
; X64: # %bb.0:
; X64-NEXT: vpbroadcastq {{.*#+}} ymm3 = [67108863,67108863,67108863,67108863]
; X64-NEXT: vpand %ymm3, %ymm0, %ymm0
; X64-NEXT: vpand %ymm3, %ymm1, %ymm1
; X64-NEXT: vpand %ymm3, %ymm2, %ymm2
; X64-NOT: vpmul
; X64-NOT: vpadd
; X64-NEXT: vpmadd52luq %ymm1, %ymm2, %ymm0
; X64-NEXT: retq
%4 = and <4 x i64> %0, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
%5 = and <4 x i64> %1, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
%6 = and <4 x i64> %2, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
%7 = mul nuw nsw <4 x i64> %5, %4
%8 = add nuw nsw <4 x i64> %7, %6
ret <4 x i64> %8
}

define dso_local noundef <4 x i64> @test_256_no_combine_evex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #1 {
; X64-LABEL: test_256_no_combine_evex:
; X64: # %bb.0:
; X64-NOT: vpmadd52
; X64-NEXT: vpmullq %ymm0, %ymm1, %ymm0
; X64-NEXT: vpaddq %ymm2, %ymm0, %ymm0
; X64-NEXT: retq
%4 = mul <4 x i64> %1, %0
%5 = add <4 x i64> %4, %2
ret <4 x i64> %5
}

define dso_local <4 x i64> @test_256_combine_vex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #2 {
; X64-LABEL: test_256_combine_vex:
; X64: # %bb.0:
; X64-NEXT: vpbroadcastq {{.*#+}} ymm3 = [67108863,67108863,67108863,67108863]
; X64-NEXT: vpand %ymm3, %ymm0, %ymm0
; X64-NEXT: vpand %ymm3, %ymm1, %ymm1
; X64-NEXT: vpand %ymm3, %ymm2, %ymm2
; X64-NOT: vpmul
; X64-NOT: vpadd
; X64-NEXT: {vex} vpmadd52luq %ymm1, %ymm2, %ymm0
; X64-NEXT: retq
%4 = and <4 x i64> %0, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
%5 = and <4 x i64> %1, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
%6 = and <4 x i64> %2, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
%7 = mul nuw nsw <4 x i64> %5, %4
%8 = add nuw nsw <4 x i64> %7, %6
ret <4 x i64> %8
}

define dso_local noundef <4 x i64> @test_256_no_combine_vex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #2 {
; X64-LABEL: test_256_no_combine_vex:
; X64: # %bb.0:
; X64-NOT: vpmadd52
; X64-NEXT: vpmullq %ymm0, %ymm1, %ymm0
; X64-NEXT: vpaddq %ymm2, %ymm0, %ymm0
; X64-NEXT: retq
%4 = mul <4 x i64> %1, %0
%5 = add <4 x i64> %4, %2
ret <4 x i64> %5
}

attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "min-legal-vector-width"="512" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+avx,+avx2,+avx512dq,+avx512f,+avx512ifma,+cmov,+crc32,+cx8,+evex512,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" }
attributes #1 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "min-legal-vector-width"="256" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+avx,+avx2,+avx512dq,+avx512f,+avx512ifma,+avx512vl,+cmov,+crc32,+cx8,+evex512,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" }
attributes #2 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "min-legal-vector-width"="256" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+avx,+avx2,+avx512dq,+avx512f,+avx512vl,+avxifma,+cmov,+crc32,+cx8,+evex512,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" }