@@ -11900,22 +11900,22 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1190011900 EVT ReducedTy = Acc.getValueType ();
1190111901 EVT FullTy = MulLHS.getValueType ();
1190211902
11903- auto ExtendToAccEltVT = [&](SDValue V) {
11904- unsigned ExtOpc = V->getOpcode () == ISD::PARTIAL_REDUCE_SMLA
11905- ? ISD::SIGN_EXTEND
11906- : ISD::ZERO_EXTEND;
11907- EVT ExtVT = V.getValueType ().changeVectorElementType (
11908- Acc.getValueType ().getVectorElementType ());
11909- if (ExtVT != FullTy)
11910- return DAG.getNode (ExtOpc, DL, ExtVT, V);
11911- return V;
11912- };
11913-
1191411903 EVT NewVT =
1191511904 EVT::getVectorVT (*DAG.getContext (), ReducedTy.getVectorElementType (),
1191611905 FullTy.getVectorElementCount ());
11917- MulLHS = ExtendToAccEltVT (MulLHS);
11918- MulRHS = ExtendToAccEltVT (MulRHS);
11906+ unsigned ExtOpc = N->getOpcode () == ISD::PARTIAL_REDUCE_SMLA
11907+ ? ISD::SIGN_EXTEND
11908+ : ISD::ZERO_EXTEND;
11909+ EVT MulLHSVT = MulLHS.getValueType ();
11910+ assert (MulLHSVT == MulRHS.getValueType () &&
11911+ " The second and third operands of a PARTIAL_REDUCE_MLA node must have "
11912+ " the same value type!" );
11913+ EVT ExtVT = MulLHSVT.changeVectorElementType (
11914+ Acc.getValueType ().getVectorElementType ());
11915+ if (ExtVT != FullTy) {
11916+ MulLHS = DAG.getNode (ExtOpc, DL, ExtVT, MulLHS);
11917+ MulRHS = DAG.getNode (ExtOpc, DL, ExtVT, MulRHS);
11918+ }
1191911919 SDValue Input = MulLHS;
1192011920 APInt ConstantOne;
1192111921 if (!ISD::isConstantSplatVector (MulRHS.getNode (), ConstantOne) ||
0 commit comments