@@ -56447,9 +56447,11 @@ static SDValue combineADC(SDNode *N, SelectionDAG &DAG,
5644756447 return SDValue();
5644856448}
5644956449
56450- static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1 ,
56450+ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N ,
5645156451 const SDLoc &DL, EVT VT,
5645256452 const X86Subtarget &Subtarget) {
56453+ using namespace SDPatternMatch;
56454+
5645356455 // Example of pattern we try to detect:
5645556457 // t := (v8i32 mul (sext (v8i16 x0)), (sext (v8i16 x1)))
5645556457 //(add (build_vector (extract_elt t, 0),
@@ -56464,15 +56466,16 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
5646456466 if (!Subtarget.hasSSE2())
5646556467 return SDValue();
5646656468
56467- if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
56468- Op1.getOpcode() != ISD::BUILD_VECTOR)
56469- return SDValue();
56470-
5647156469 if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
5647256470 VT.getVectorNumElements() < 4 ||
5647356471 !isPowerOf2_32(VT.getVectorNumElements()))
5647456472 return SDValue();
5647556473
56474+ SDValue Op0, Op1;
56475+ if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
56476+ m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
56477+ return SDValue();
56478+
5647656479 // Check if one of Op0,Op1 is of the form:
5647756480 // (build_vector (extract_elt Mul, 0),
5647856481 // (extract_elt Mul, 2),
@@ -56489,26 +56492,23 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
5648956492 SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
5649056493 Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
5649156494 // TODO: Be more tolerant to undefs.
56492- if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56493- Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56494- Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56495- Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
56496- return SDValue();
56497- auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
56498- auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
56499- auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
56500- auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
56501- if (!Const0L || !Const1L || !Const0H || !Const1H)
56495+ APInt Idx0L, Idx0H, Idx1L, Idx1H;
56496+ if (!sd_match(Op0L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56497+ m_ConstInt(Idx0L))) ||
56498+ !sd_match(Op0H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56499+ m_ConstInt(Idx0H))) ||
56500+ !sd_match(Op1L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56501+ m_ConstInt(Idx1L))) ||
56502+ !sd_match(Op1H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
56503+ m_ConstInt(Idx1H))))
5650256504 return SDValue();
56503- unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
56504- Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
5650556505 // Commutativity of mul allows factors of a product to reorder.
56506- if (Idx0L > Idx1L)
56506+ if (Idx0L.getZExtValue() > Idx1L.getZExtValue() )
5650756507 std::swap(Idx0L, Idx1L);
56508- if (Idx0H > Idx1H)
56508+ if (Idx0H.getZExtValue() > Idx1H.getZExtValue() )
5650956509 std::swap(Idx0H, Idx1H);
5651056510 // Commutativity of add allows pairs of factors to reorder.
56511- if (Idx0L > Idx0H) {
56511+ if (Idx0L.getZExtValue() > Idx0H.getZExtValue() ) {
5651256512 std::swap(Idx0L, Idx0H);
5651356513 std::swap(Idx1L, Idx1H);
5651456514 }
@@ -56555,39 +56555,26 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
5655556555// Attempt to turn this pattern into PMADDWD.
5655656556// (add (mul (sext (build_vector)), (sext (build_vector))),
5655756557// (mul (sext (build_vector)), (sext (build_vector))))
56558- static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1 ,
56558+ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDNode *N ,
5655956559 const SDLoc &DL, EVT VT,
5656056560 const X86Subtarget &Subtarget) {
56561- if (!Subtarget.hasSSE2())
56562- return SDValue();
56561+ using namespace SDPatternMatch;
5656356562
56564- if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL )
56563+ if (!Subtarget.hasSSE2() )
5656556564 return SDValue();
5656656565
5656756566 if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
5656856567 VT.getVectorNumElements() < 4 ||
5656956568 !isPowerOf2_32(VT.getVectorNumElements()))
5657056569 return SDValue();
5657156570
56572- SDValue N00 = N0.getOperand(0);
56573- SDValue N01 = N0.getOperand(1);
56574- SDValue N10 = N1.getOperand(0);
56575- SDValue N11 = N1.getOperand(1);
56576-
5657756571 // All inputs need to be sign extends.
5657856572 // TODO: Support ZERO_EXTEND from known positive?
56579- if (N00.getOpcode() != ISD::SIGN_EXTEND ||
56580- N01.getOpcode() != ISD::SIGN_EXTEND ||
56581- N10.getOpcode() != ISD::SIGN_EXTEND ||
56582- N11.getOpcode() != ISD::SIGN_EXTEND)
56573+ SDValue N00, N01, N10, N11;
56574+ if (!sd_match(N, m_Add(m_Mul(m_SExt(m_Value(N00)), m_SExt(m_Value(N01))),
56575+ m_Mul(m_SExt(m_Value(N10)), m_SExt(m_Value(N11))))))
5658356576 return SDValue();
5658456577
56585- // Peek through the extends.
56586- N00 = N00.getOperand(0);
56587- N01 = N01.getOperand(0);
56588- N10 = N10.getOperand(0);
56589- N11 = N11.getOperand(0);
56590-
5659156578 // Must be extending from vXi16.
5659256579 EVT InVT = N00.getValueType();
5659356580 if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
@@ -56614,34 +56601,26 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
5661456601 SDValue N10Elt = N10.getOperand(i);
5661556602 SDValue N11Elt = N11.getOperand(i);
5661656603 // TODO: Be more tolerant to undefs.
56617- if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56618- N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56619- N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
56620- N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
56621- return SDValue();
56622- auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
56623- auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
56624- auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
56625- auto *ConstN11Elt = dyn_cast<ConstantSDNode> (N11Elt.getOperand(1));
56626- if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt )
56604+ SDValue N00In, N01In, N10In, N11In;
56605+ APInt IdxN00, IdxN01, IdxN10, IdxN11;
56606+ if (!sd_match(N00Elt, m_BinOp( ISD::EXTRACT_VECTOR_ELT, m_Value(N00In),
56607+ m_ConstInt(IdxN00))) ||
56608+ !sd_match(N01Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N01In),
56609+ m_ConstInt(IdxN01))) ||
56610+ !sd_match(N10Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N10In),
56611+ m_ConstInt(IdxN10))) ||
56612+ !sd_match (N11Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N11In),
56613+ m_ConstInt(IdxN11))) )
5662756614 return SDValue();
56628- unsigned IdxN00 = ConstN00Elt->getZExtValue();
56629- unsigned IdxN01 = ConstN01Elt->getZExtValue();
56630- unsigned IdxN10 = ConstN10Elt->getZExtValue();
56631- unsigned IdxN11 = ConstN11Elt->getZExtValue();
5663256615 // Add is commutative so indices can be reordered.
56633- if (IdxN00 > IdxN10) {
56616+ if (IdxN00.getZExtValue() > IdxN10.getZExtValue() ) {
5663456617 std::swap(IdxN00, IdxN10);
5663556618 std::swap(IdxN01, IdxN11);
5663656619 }
5663756620 // N0 indices must be the even element. N1 indices must be the next odd element.
56638- if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
56639- IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
56621+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || IdxN01 != 2 * i ||
56622+ IdxN11 != 2 * i + 1)
5664056623 return SDValue();
56641- SDValue N00In = N00Elt.getOperand(0);
56642- SDValue N01In = N01Elt.getOperand(0);
56643- SDValue N10In = N10Elt.getOperand(0);
56644- SDValue N11In = N11Elt.getOperand(0);
5664556624
5664656625 // First time we find an input capture it.
5664756626 if (!In0) {
@@ -56815,9 +56794,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
5681556794 if (SDValue Select = pushAddIntoCmovOfConsts(N, DL, DAG, Subtarget))
5681656795 return Select;
5681756796
56818- if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1 , DL, VT, Subtarget))
56797+ if (SDValue MAdd = matchPMADDWD(DAG, N , DL, VT, Subtarget))
5681956798 return MAdd;
56820- if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1 , DL, VT, Subtarget))
56799+ if (SDValue MAdd = matchPMADDWD_2(DAG, N , DL, VT, Subtarget))
5682156800 return MAdd;
5682256801 if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT))
5682356802 return MAdd;
0 commit comments