@@ -21005,9 +21005,9 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2100521005 if (N->getOpcode() != ISD::SMIN)
2100621006 return SDValue();
2100721007
21008- EVT VT = N->getValueType(0);
21008+ EVT DestVT = N->getValueType(0);
2100921009
21010- if (!VT .isVector() || VT .getScalarSizeInBits() > 64)
21010+ if (!DestVT .isVector() || DestVT .getScalarSizeInBits() > 64 || DestVT.isScalableVector() )
2101121011 return SDValue();
2101221012
2101321013 ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
@@ -21049,28 +21049,36 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
2104921049 SDValue SExt0 = Mul.getOperand(0);
2105021050 SDValue SExt1 = Mul.getOperand(1);
2105121051
21052+ if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
21053+ SExt1.getOpcode() != ISD::SIGN_EXTEND)
21054+ return SDValue();
21055+
2105221056 EVT SExt0Type = SExt0.getOperand(0).getValueType();
2105321057 EVT SExt1Type = SExt1.getOperand(0).getValueType();
2105421058
21055- if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
21056- SExt1.getOpcode() != ISD::SIGN_EXTEND || SExt0Type != SExt1Type ||
21059+ if (SExt0Type != SExt1Type ||
2105721060 SExt0Type.getScalarType() != ScalarType ||
2105821061 SExt0Type.getFixedSizeInBits() > 128)
2105921062 return SDValue();
2106021063
21061- // Source vectors with width < 64 are illegal and will need to be extended
21062- unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
21063- SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
21064- SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
21065-
2106621064 SDLoc DL(N);
21065+ SDValue V0 = SExt0.getOperand(0);
21066+ SDValue V1 = SExt1.getOperand(0);
21067+
21068+ // Ensure input vectors are extended to legal types
21069+ if (SExt0Type.getFixedSizeInBits() < 64) {
21070+ unsigned VecNumElements = SExt0Type.getVectorNumElements();
21071+ EVT ExtVecVT =
21072+ MVT::getVectorVT(MVT::getIntegerVT(64 / VecNumElements),
21073+ VecNumElements);
21074+ V0 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V0);
21075+ V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V1);
21076+ }
21077+
2106721078 SDValue SQDMULH =
2106821079 DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
21069- EVT DestVT = N->getValueType(0);
21070- if (DestVT.getScalarSizeInBits() > SExt0Type.getScalarSizeInBits())
21071- return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
2107221080
21073- return SQDMULH;
21081+ return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH) ;
2107421082}
2107521083
2107621084static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
0 commit comments