Skip to content
46 changes: 46 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57966,6 +57966,49 @@ static SDValue pushAddIntoCmovOfConsts(SDNode *N, const SDLoc &DL,
Cmov.getOperand(3));
}

// Attempt to turn ADD(MUL(x, y), acc) -> VPMADD52L(x, y, acc).
//
// VPMADD52L multiplies the low 52 bits of each multiplicand and adds the low
// 52 bits of that product to the 64-bit accumulator. The rewrite is therefore
// only performed when the upper 12 bits of x, y and MUL(x, y) are all known to
// be 0, so that the truncated product equals the full 64-bit product.
//
// \param N         The ISD::ADD node being combined.
// \param DAG       DAG used for known-bits queries and node creation.
// \param DL        Debug location attached to the new node(s).
// \param VT        Result type of \p N.
// \param Subtarget Queried for IFMA / AVX-IFMA / VLX availability.
// \returns The replacement value, or an empty SDValue if the type, subtarget
//          or pattern does not qualify.
static SDValue matchVPMADD52(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
                             EVT VT, const X86Subtarget &Subtarget) {
  using namespace SDPatternMatch;
  if (!VT.isVector() || VT.getScalarSizeInBits() != 64 ||
      (!Subtarget.hasAVXIFMA() && !Subtarget.hasIFMA()))
    return SDValue();

  // Query the vector width once; every remaining type check depends on it.
  const uint64_t TotalSize = VT.getSizeInBits();
  if (TotalSize < 128 || !isPowerOf2_64(TotalSize))
    return SDValue();

  // Need AVX-512VL vector length extensions if operating on XMM/YMM registers
  if (!Subtarget.hasAVXIFMA() && !Subtarget.hasVLX() && TotalSize < 512)
    return SDValue();

  SDValue X, Y, Acc;
  if (!sd_match(N, m_Add(m_Mul(m_Value(X), m_Value(Y)), m_Value(Acc))))
    return SDValue();

  // Known-bits analysis walks the DAG, so bail out as soon as one operand
  // fails instead of analysing everything up front.
  KnownBits KnownX = DAG.computeKnownBits(X);
  if (KnownX.countMinLeadingZeros() < 12)
    return SDValue();

  KnownBits KnownY = DAG.computeKnownBits(Y);
  if (KnownY.countMinLeadingZeros() < 12)
    return SDValue();

  // The product itself must also fit in 52 bits, otherwise the bits that
  // VPMADD52L discards would have contributed to the original ADD.
  KnownBits KnownMul = KnownBits::mul(KnownX, KnownY);
  if (KnownMul.countMinLeadingZeros() < 12)
    return SDValue();

  // Builds one VPMADD52L for a (possibly split) slice of the operands.
  // SubOps arrive in the order passed to SplitOpsAndApply below: {Acc, X, Y};
  // the node is created with the multiplicands first and the accumulator last.
  auto VPMADD52Builder = [](SelectionDAG &G, const SDLoc &DL,
                            ArrayRef<SDValue> SubOps) {
    EVT SubVT = SubOps[0].getValueType();
    assert(SubVT.getScalarSizeInBits() == 64 &&
           "Unexpected element size, only supports 64bit size");
    return G.getNode(X86ISD::VPMADD52L, DL, SubVT, SubOps[1] /*X*/,
                     SubOps[2] /*Y*/, SubOps[0] /*Acc*/);
  };

  return SplitOpsAndApply(DAG, Subtarget, DL, VT, {Acc, X, Y}, VPMADD52Builder,
                          /*CheckBWI*/ false);
}

static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
Expand Down Expand Up @@ -58069,6 +58112,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
Op0.getOperand(0), Op0.getOperand(2));
}

if (SDValue IFMA52 = matchVPMADD52(N, DAG, DL, VT, Subtarget))
return IFMA52;

return combineAddOrSubToADCOrSBB(N, DL, DAG);
}

Expand Down
Loading