@@ -12514,11 +12514,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1251412514 SDValue Op1 = N->getOperand(1);
1251512515 SDValue Op2 = N->getOperand(2);
1251612516
12517- if (Op1->getOpcode() != ISD::MUL)
12518- return SDValue();
12519-
1252012517 APInt ConstantOne;
12521- if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12518+ if (Op1->getOpcode() != ISD::MUL ||
12519+ !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
1252212520 !ConstantOne.isOne())
1252312521 return SDValue();
1252412522
@@ -12529,29 +12527,28 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1252912527 if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
1253012528 return SDValue();
1253112529
12530+ // For a 2-stage extend the signedness of both of the extends must be the
12531+ // same. This is so the node can be folded into only a signed or unsigned
12532+ // node.
1253212533 SDValue LHSExtOp = LHS->getOperand(0);
1253312534 SDValue RHSExtOp = RHS->getOperand(0);
1253412535 EVT LHSExtOpVT = LHSExtOp.getValueType();
12535- if (LHSExtOpVT != RHSExtOp.getValueType())
12536+ if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode )
1253612537 return SDValue();
1253712538
1253812539 // FIXME: Add a check to only perform the DAG combine if there is lowering
1253912540 // provided by the target
1254012541
12541- bool LHSIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12542- bool RHSIsSigned = RHSOpcode == ISD::SIGN_EXTEND;
12543- if (LHSIsSigned != RHSIsSigned)
12544- return SDValue();
12542+ bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
1254512543
1254612544 bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1254712545 EVT AccElemVT = Acc.getValueType().getVectorElementType();
12548- if (LHSIsSigned != NodeIsSigned &&
12549- (Op1.getValueType().getVectorElementType() != AccElemVT ||
12550- Op2.getValueType().getVectorElementType() != AccElemVT))
12546+ if (ExtIsSigned != NodeIsSigned &&
12547+ Op1.getValueType().getVectorElementType() != AccElemVT)
1255112548 return SDValue();
1255212549
1255312550 unsigned NewOpcode =
12554- LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12551+ ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
1255512552 return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
1255612553 RHSExtOp);
1255712554}
0 commit comments