@@ -45513,6 +45513,7 @@ static SDValue combinevXi1ConstantToInteger(SDValue Op, SelectionDAG &DAG) {
4551345513static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG,
4551445514 TargetLowering::DAGCombinerInfo &DCI,
4551545515 const X86Subtarget &Subtarget) {
45516+ using namespace SDPatternMatch;
4551645517 assert(N->getOpcode() == ISD::BITCAST && "Expected a bitcast");
4551745518
4551845519 if (!DCI.isBeforeLegalizeOps())
@@ -45526,34 +45527,25 @@ static SDValue combineCastedMaskArithmetic(SDNode *N, SelectionDAG &DAG,
4552645527 SDValue Op = N->getOperand(0);
4552745528 EVT SrcVT = Op.getValueType();
4552845529
45529- if (!Op.hasOneUse())
45530- return SDValue();
45531-
45532- // Look for logic ops.
45533- if (Op.getOpcode() != ISD::AND &&
45534- Op.getOpcode() != ISD::OR &&
45535- Op.getOpcode() != ISD::XOR)
45536- return SDValue();
45537-
4553845530 // Make sure we have a bitcast between mask registers and a scalar type.
4553945531 if (!(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
4554045532 DstVT.isScalarInteger()) &&
4554145533 !(DstVT.isVector() && DstVT.getVectorElementType() == MVT::i1 &&
4554245534 SrcVT.isScalarInteger()))
4554345535 return SDValue();
4554445536
45545- SDValue LHS = Op.getOperand(0);
45546- SDValue RHS = Op.getOperand(1);
45537+ SDValue LHS, RHS;
4554745538
45548- if (LHS.hasOneUse() && LHS.getOpcode() == ISD::BITCAST &&
45549- LHS.getOperand(0).getValueType() == DstVT)
45550- return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT, LHS.getOperand(0),
45551- DAG.getBitcast(DstVT, RHS));
45539+ // Look for logic ops.
45540+ if (!sd_match(Op, m_OneUse(m_BitwiseLogic(m_Value(LHS), m_Value(RHS)))))
45541+ return SDValue();
4555245542
45553- if (RHS.hasOneUse() && RHS.getOpcode() == ISD::BITCAST &&
45554- RHS.getOperand(0).getValueType() == DstVT)
45543+ // If either operand was bitcast from DstVT, then perform logic with DstVT (at
45544+ // least one of the getBitcast() will fold away).
45545+ if (sd_match(LHS, m_OneUse(m_BitCast(m_SpecificVT(DstVT)))) ||
45546+ sd_match(RHS, m_OneUse(m_BitCast(m_SpecificVT(DstVT)))))
4555545547 return DAG.getNode(Op.getOpcode(), SDLoc(N), DstVT,
45556- DAG.getBitcast(DstVT, LHS), RHS.getOperand(0 ));
45548+ DAG.getBitcast(DstVT, LHS), DAG.getBitcast(DstVT, RHS ));
4555745549
4555845550 // If the RHS is a vXi1 build vector, this is a good reason to flip too.
4555945551 // Most of these have to move a constant from the scalar domain anyway.
0 commit comments