@@ -619,7 +619,7 @@ namespace {
619619 SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620620 const TargetLowering &TLI);
621621 SDValue foldPartialReduceMLAMulOp(SDNode *N);
622- SDValue foldPartialReduceMLANoMulOp (SDNode *N);
622+ SDValue foldPartialReduceAdd (SDNode *N);
623623
624624 SDValue CombineExtLoad(SDNode *N);
625625 SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12606,7 +12606,7 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1260612606SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1260712607 if (SDValue Res = foldPartialReduceMLAMulOp(N))
1260812608 return Res;
12609- if (SDValue Res = foldPartialReduceMLANoMulOp (N))
12609+ if (SDValue Res = foldPartialReduceAdd (N))
1261012610 return Res;
1261112611 return SDValue();
1261212612}
@@ -12682,11 +12682,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1268212682 RHSExtOp);
1268312683}
1268412684
12685- // Makes partial.reduce.umla(acc, zext(op1 ), splat(1)) into
12686- // partial.reduce.umla(acc, op, splat(trunc(1)))
12687- // Makes partial.reduce.smla(acc, sext(op1 ), splat(1)) into
12688- // partial.reduce.smla(acc, op, splat(trunc(1)))
12689- SDValue DAGCombiner::foldPartialReduceMLANoMulOp (SDNode *N) {
12685+ // partial.reduce.umla(acc, zext(op ), splat(1))
12686+ // -> partial.reduce.umla(acc, op, splat(trunc(1)))
12687+ // partial.reduce.smla(acc, sext(op ), splat(1))
12688+ // -> partial.reduce.smla(acc, op, splat(trunc(1)))
12689+ SDValue DAGCombiner::foldPartialReduceAdd (SDNode *N) {
1269012690 SDLoc DL(N);
1269112691 SDValue Acc = N->getOperand(0);
1269212692 SDValue Op1 = N->getOperand(1);
@@ -12703,25 +12703,20 @@ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
1270312703
1270412704 SDValue UnextOp1 = Op1.getOperand(0);
1270512705 EVT UnextOp1VT = UnextOp1.getValueType();
12706-
1270712706 if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
1270812707 return SDValue();
1270912708
12710- SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12711-
1271212709 bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12713-
1271412710 bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1271512711 EVT AccElemVT = Acc.getValueType().getVectorElementType();
1271612712 if (Op1IsSigned != NodeIsSigned &&
12717- (Op1.getValueType().getVectorElementType() != AccElemVT ||
12718- Op2.getValueType().getVectorElementType() != AccElemVT))
12713+ Op1.getValueType().getVectorElementType() != AccElemVT)
1271912714 return SDValue();
1272012715
1272112716 unsigned NewOpcode =
1272212717 Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
1272312718 return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12724- TruncOp2 );
12719+ DAG.getConstant(1, DL, UnextOp1VT) );
1272512720}
1272612721
1272712722SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
0 commit comments