@@ -3220,8 +3220,30 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
32203220void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA (SDNode *N, SDValue &Lo,
32213221 SDValue &Hi) {
32223222 SDLoc DL (N);
3223- SDValue Expanded = TLI.expandPartialReduceMLA (N, DAG);
3224- std::tie (Lo, Hi) = DAG.SplitVector (Expanded, DL);
3223+ SDValue Acc = N->getOperand (0 );
3224+ SDValue Input1 = N->getOperand (1 );
3225+ SDValue Input2 = N->getOperand (2 );
3226+
3227+ SDValue AccLo, AccHi;
3228+ std::tie (AccLo, AccHi) = DAG.SplitVector (Acc, DL);
3229+ unsigned Opcode = N->getOpcode ();
3230+
3231+ // If the input types don't need splitting, just accumulate into the
3232+ // low part of the accumulator.
3233+ if (getTypeAction (Input1.getValueType ()) != TargetLowering::TypeSplitVector) {
3234+ Lo = DAG.getNode (Opcode, DL, AccLo.getValueType (), AccLo, Input1, Input2);
3235+ Hi = AccHi;
3236+ return ;
3237+ }
3238+
3239+ SDValue Input1Lo, Input1Hi;
3240+ SDValue Input2Lo, Input2Hi;
3241+ std::tie (Input1Lo, Input1Hi) = DAG.SplitVector (Input1, DL);
3242+ std::tie (Input2Lo, Input2Hi) = DAG.SplitVector (Input2, DL);
3243+ EVT ResultVT = AccLo.getValueType ();
3244+
3245+ Lo = DAG.getNode (Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
3246+ Hi = DAG.getNode (Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
32253247}
32263248
32273249void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE (SDNode *N) {
@@ -4501,7 +4523,20 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
45014523}
45024524
45034525SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA (SDNode *N) {
4504- return TLI.expandPartialReduceMLA (N, DAG);
4526+ SDValue Acc = N->getOperand (0 );
4527+ assert (getTypeAction (Acc.getValueType ()) != TargetLowering::TypeSplitVector &&
4528+ " Accumulator should already be a legal type, and shouldn't need "
4529+ " further splitting" );
4530+
4531+ SDLoc DL (N);
4532+ SDValue Input1Lo, Input1Hi, Input2Lo, Input2Hi;
4533+ std::tie (Input1Lo, Input1Hi) = DAG.SplitVector (N->getOperand (1 ), DL);
4534+ std::tie (Input2Lo, Input2Hi) = DAG.SplitVector (N->getOperand (2 ), DL);
4535+ unsigned Opcode = N->getOpcode ();
4536+ EVT ResultVT = Acc.getValueType ();
4537+
4538+ SDValue Lo = DAG.getNode (Opcode, DL, ResultVT, Acc, Input1Lo, Input2Lo);
4539+ return DAG.getNode (Opcode, DL, ResultVT, Lo, Input1Hi, Input2Hi);
45054540}
45064541
45074542// ===----------------------------------------------------------------------===//
0 commit comments