@@ -18177,17 +18177,20 @@ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
1817718177 assert(VT == Op1.getSimpleValueType() &&
1817818178 VT.getVectorElementType() == MVT::i32);
1817918179
18180- assert(VT.isFixedLengthVector());
18181- MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18182- SDValue Passthru = convertToScalableVector(
18183- ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
18184- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18185- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18186-
18180+ SDValue Passthru = DAG.getConstant(0, DL, VT);
18181+ MVT ContainerVT = VT;
18182+ if (VT.isFixedLengthVector()) {
18183+ ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18184+ Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
18185+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18186+ Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18187+ }
1818718188 auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
1818818189 SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
1818918190 {Op0, Op1, Passthru, Mask, VL});
18190- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18191+ if (VT.isFixedLengthVector())
18192+ return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18193+ return LocalAccum;
1819118194}
1819218195
1819318196static MVT getQDOTXResultType(MVT OpVT) {
@@ -18207,7 +18210,7 @@ static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
1820718210 EVT AVT = A.getValueType();
1820818211 EVT BVT = B.getValueType();
1820918212 assert(AVT.getVectorElementType() == BVT.getVectorElementType());
18210- if (AVT.getVectorNumElements () > BVT.getVectorNumElements ()) {
18213+ if (AVT.getVectorMinNumElements () > BVT.getVectorMinNumElements ()) {
1821118214 std::swap(A, B);
1821218215 std::swap(AVT, BVT);
1821318216 }
@@ -18641,17 +18644,19 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
1864118644static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
1864218645 const RISCVSubtarget &Subtarget) {
1864318646
18644- assert(N->getOpcode() == RISCVISD::ADD_VL);
18647+ assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD );
1864518648
1864618649 if (!N->getValueType(0).isVector())
1864718650 return SDValue();
1864818651
1864918652 SDValue Addend = N->getOperand(0);
1865018653 SDValue DotOp = N->getOperand(1);
1865118654
18652- SDValue AddPassthruOp = N->getOperand(2);
18653- if (!AddPassthruOp.isUndef())
18654- return SDValue();
18655+ if (N->getOpcode() == RISCVISD::ADD_VL) {
18656+ SDValue AddPassthruOp = N->getOperand(2);
18657+ if (!AddPassthruOp.isUndef())
18658+ return SDValue();
18659+ }
1865518660
1865618661 auto IsVqdotqOpc = [](unsigned Opc) {
1865718662 switch (Opc) {
@@ -18670,8 +18675,15 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
1867018675 if (!IsVqdotqOpc(DotOp.getOpcode()))
1867118676 return SDValue();
1867218677
18673- SDValue AddMask = N->getOperand(3);
18674- SDValue AddVL = N->getOperand(4);
18678+ auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
18679+ const RISCVSubtarget &Subtarget) {
18680+ if (N->getOpcode() == ISD::ADD) {
18681+ SDLoc DL(N);
18682+ return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
18683+ Subtarget);
18684+ }
18685+ return std::make_pair(N->getOperand(3), N->getOperand(4));
18686+ }(N, DAG, Subtarget);
1867518687
1867618688 SDValue MulVL = DotOp.getOperand(4);
1867718689 if (AddVL != MulVL)
@@ -19309,6 +19321,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1930919321 return V;
1931019322 if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
1931119323 return V;
19324+ if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
19325+ return V;
1931219326 return performADDCombine(N, DCI, Subtarget);
1931319327 }
1931419328 case ISD::SUB: {
0 commit comments