@@ -1578,6 +1578,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
1578
1578
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
1579
1579
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
1580
1580
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
1581
+
1582
+ if (Subtarget.useRVVForFixedLengthVectors()) {
1583
+ for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
1584
+ if (VT.getVectorElementType() != MVT::i32 ||
1585
+ !useRVVForFixedLengthVectorVT(VT))
1586
+ continue;
1587
+ ElementCount EC = VT.getVectorElementCount();
1588
+ MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
1589
+ setPartialReduceMLAAction(VT, ArgVT, Custom);
1590
+ }
1591
+ }
1581
1592
}
1582
1593
1583
1594
// Function alignments.
@@ -8389,12 +8400,26 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
8389
8400
VT.getVectorElementType() == MVT::i32);
8390
8401
SDValue A = Op.getOperand(1);
8391
8402
SDValue B = Op.getOperand(2);
8392
- assert(A.getSimpleValueType() == B.getSimpleValueType() &&
8393
- A.getSimpleValueType().getVectorElementType() == MVT::i8);
8403
+ MVT ArgVT = A.getSimpleValueType();
8404
+ assert(ArgVT == B.getSimpleValueType() &&
8405
+ ArgVT.getVectorElementType() == MVT::i8);
8406
+
8407
+ MVT ContainerVT = VT;
8408
+ if (VT.isFixedLengthVector()) {
8409
+ ContainerVT = getContainerForFixedLengthVector(VT);
8410
+ Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
8411
+ MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
8412
+ A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
8413
+ B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
8414
+ }
8415
+
8394
8416
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
8395
8417
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
8396
- auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
8397
- return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
8418
+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
8419
+ SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
8420
+ if (VT.isFixedLengthVector())
8421
+ Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
8422
+ return Res;
8398
8423
}
8399
8424
8400
8425
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
0 commit comments