@@ -1578,6 +1578,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15781578 setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
15791579 setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
15801580 setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
1581+
1582+ if (Subtarget.useRVVForFixedLengthVectors()) {
1583+ for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
1584+ if (VT.getVectorElementType() != MVT::i32 ||
1585+ !useRVVForFixedLengthVectorVT(VT))
1586+ continue;
1587+ ElementCount EC = VT.getVectorElementCount();
1588+ MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
1589+ setPartialReduceMLAAction(VT, ArgVT, Custom);
1590+ }
1591+ }
15811592 }
15821593
15831594 // Function alignments.
@@ -8389,12 +8400,26 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
83898400 VT.getVectorElementType() == MVT::i32);
83908401 SDValue A = Op.getOperand(1);
83918402 SDValue B = Op.getOperand(2);
8392- assert(A.getSimpleValueType() == B.getSimpleValueType() &&
8393- A.getSimpleValueType().getVectorElementType() == MVT::i8);
8403+ MVT ArgVT = A.getSimpleValueType();
8404+ assert(ArgVT == B.getSimpleValueType() &&
8405+ ArgVT.getVectorElementType() == MVT::i8);
8406+
8407+ MVT ContainerVT = VT;
8408+ if (VT.isFixedLengthVector()) {
8409+ ContainerVT = getContainerForFixedLengthVector(VT);
8410+ Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
8411+ MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
8412+ A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
8413+ B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
8414+ }
8415+
83948416 bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
83958417 unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
8396- auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
8397- return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
8418+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
8419+ SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
8420+ if (VT.isFixedLengthVector())
8421+ Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
8422+ return Res;
83988423}
83998424
84008425static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
0 commit comments