@@ -18177,16 +18177,38 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1817718177 unsigned ExtOpcode = Op0.getOpcode();
1817818178 SDValue A = Op0;
1817918179 SDValue B;
18180+ unsigned DotOpcode;
1818018181 if (ExtOpcode == ISD::MUL) {
1818118182 A = Op0.getOperand(0);
1818218183 B = Op0.getOperand(1);
18183- if (A.getOpcode() != B.getOpcode() ||
18184- A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
18184+ if (A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
1818518185 return SDValue();
18186- ExtOpcode = A.getOpcode();
18187- }
18188- if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
18186+ auto OpCodeA = A.getOpcode();
18187+ if (OpCodeA != ISD::ZERO_EXTEND && OpCodeA != ISD::SIGN_EXTEND)
18188+ return SDValue();
18189+
18190+ auto OpCodeB = B.getOpcode();
18191+ if (OpCodeB != ISD::ZERO_EXTEND && OpCodeB != ISD::SIGN_EXTEND)
18192+ return SDValue();
18193+
18194+ if (OpCodeA == OpCodeB) {
18195+ DotOpcode =
18196+ OpCodeA == ISD::ZERO_EXTEND ? AArch64ISD::UDOT : AArch64ISD::SDOT;
18197+ } else {
18198+ // Check USDOT support support
18199+ if (!ST->hasMatMulInt8())
18200+ return SDValue();
18201+ DotOpcode = AArch64ISD::USDOT;
18202+ if (OpCodeA == ISD::SIGN_EXTEND)
18203+ std::swap(A, B);
18204+ }
18205+ } else if (ExtOpcode == ISD::ZERO_EXTEND) {
18206+ DotOpcode = AArch64ISD::UDOT;
18207+ } else if (ExtOpcode == ISD::SIGN_EXTEND) {
18208+ DotOpcode = AArch64ISD::SDOT;
18209+ } else {
1818918210 return SDValue();
18211+ }
1819018212
1819118213 EVT Op0VT = A.getOperand(0).getValueType();
1819218214 bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
@@ -18212,8 +18234,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
1821218234 NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
1821318235 TargetType = MVT::v2i32;
1821418236 }
18215- auto DotOpcode =
18216- (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
1821718237 // Handle the case where we need to generate only one Dot operation.
1821818238 if (NumOfVecReduce == 1) {
1821918239 SDValue Zeros = DAG.getConstant(0, DL, TargetType);
0 commit comments