@@ -12504,25 +12504,17 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1250412504}
1250512505
1250612506SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12507- // Only perform the DAG combine if there is custom lowering provided by the
12508- // target.
12509- if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
12510- N->getOperand(1).getValueType()))
12511- return SDValue();
12512-
12513- if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
12514- return Res;
12515- if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
12516- return Res;
12517- return SDValue();
12518- }
12519-
12520- SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
12521- // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
12522- // PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
12507+ // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
12508+ // Splat(1)) into
12509+ // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
12510+ // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
12511+ // Splat(1)) into
12512+ // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
1252312513 SDLoc DL(N);
1252412514
12515+ SDValue Op0 = N->getOperand(0);
1252512516 SDValue Op1 = N->getOperand(1);
12517+
1252612518 if (Op1->getOpcode() != ISD::MUL)
1252712519 return SDValue();
1252812520
@@ -12531,18 +12523,8 @@ SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
1253112523 !ConstantOne.isOne())
1253212524 return SDValue();
1253312525
12534- return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
12535- Op1->getOperand(0), Op1->getOperand(1));
12536- }
12537-
12538- SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
12539- // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
12540- // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
12541- // PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
12542- // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
12543- SDLoc DL(N);
12544- SDValue ExtMulOpLHS = N->getOperand(1);
12545- SDValue ExtMulOpRHS = N->getOperand(2);
12526+ SDValue ExtMulOpLHS = Op1->getOperand(0);
12527+ SDValue ExtMulOpRHS = Op1->getOperand(1);
1254612528 unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
1254712529 unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
1254812530 if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
@@ -12554,6 +12536,10 @@ SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
1255412536 EVT MulOpLHSVT = MulOpLHS.getValueType();
1255512537 if (MulOpLHSVT != MulOpRHS.getValueType())
1255612538 return SDValue();
12539+ // Only perform the DAG combine if there is custom lowering provided by the
12540+ // target
12541+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), MulOpLHSVT))
12542+ return SDValue();
1255712543
1255812544 bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
1255912545 bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -12562,8 +12548,8 @@ SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
1256212548
1256312549 unsigned NewOpcode =
1256412550 LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12565- return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0) ,
12566- MulOpLHS, MulOpRHS);
12551+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Op0, MulOpLHS ,
12552+ MulOpRHS);
1256712553}
1256812554
1256912555SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
0 commit comments