@@ -11894,23 +11894,17 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1189411894 EVT ExtMulOpVT =
1189511895 EVT::getVectorVT (*DAG.getContext (), AccVT.getVectorElementType (),
1189611896 MulOpVT.getVectorElementCount ());
11897+
11898+ unsigned ExtOpcLHS = N->getOpcode () == ISD::PARTIAL_REDUCE_UMLA
11899+ ? ISD::ZERO_EXTEND
11900+ : ISD::SIGN_EXTEND;
11901+ unsigned ExtOpcRHS = N->getOpcode () == ISD::PARTIAL_REDUCE_SMLA
11902+ ? ISD::SIGN_EXTEND
11903+ : ISD::ZERO_EXTEND;
11904+
1189711905 if (ExtMulOpVT != MulOpVT) {
11898- switch (N->getOpcode ()) {
11899- case ISD::PARTIAL_REDUCE_SMLA:
11900- MulLHS = DAG.getNode (ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
11901- MulRHS = DAG.getNode (ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulRHS);
11902- break ;
11903- case ISD::PARTIAL_REDUCE_UMLA:
11904- MulLHS = DAG.getNode (ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulLHS);
11905- MulRHS = DAG.getNode (ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
11906- break ;
11907- case ISD::PARTIAL_REDUCE_SUMLA:
11908- MulLHS = DAG.getNode (ISD::SIGN_EXTEND, DL, ExtMulOpVT, MulLHS);
11909- MulRHS = DAG.getNode (ISD::ZERO_EXTEND, DL, ExtMulOpVT, MulRHS);
11910- break ;
11911- default :
11912- llvm_unreachable (" unexpected opcode" );
11913- }
11906+ MulLHS = DAG.getNode (ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
11907+ MulRHS = DAG.getNode (ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
1191411908 }
1191511909 SDValue Input = MulLHS;
1191611910 APInt ConstantOne;
0 commit comments