Skip to content

Commit 77a3f81

Browse files
authored
[RISCV] Custom lower fixed length partial.reduce to zvqdotq (#141180)
This is a follow on to 9b4de7 which handles the fixed vector cases. In retrospect, this is simple enough if probably should have just been part of the original commit, but oh well.
1 parent 1d411f2 commit 77a3f81

File tree

2 files changed

+428
-41
lines changed

2 files changed

+428
-41
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15781578
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
15791579
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
15801580
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
1581+
1582+
if (Subtarget.useRVVForFixedLengthVectors()) {
1583+
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
1584+
if (VT.getVectorElementType() != MVT::i32 ||
1585+
!useRVVForFixedLengthVectorVT(VT))
1586+
continue;
1587+
ElementCount EC = VT.getVectorElementCount();
1588+
MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
1589+
setPartialReduceMLAAction(VT, ArgVT, Custom);
1590+
}
1591+
}
15811592
}
15821593

15831594
// Function alignments.
@@ -8389,12 +8400,26 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
83898400
VT.getVectorElementType() == MVT::i32);
83908401
SDValue A = Op.getOperand(1);
83918402
SDValue B = Op.getOperand(2);
8392-
assert(A.getSimpleValueType() == B.getSimpleValueType() &&
8393-
A.getSimpleValueType().getVectorElementType() == MVT::i8);
8403+
MVT ArgVT = A.getSimpleValueType();
8404+
assert(ArgVT == B.getSimpleValueType() &&
8405+
ArgVT.getVectorElementType() == MVT::i8);
8406+
8407+
MVT ContainerVT = VT;
8408+
if (VT.isFixedLengthVector()) {
8409+
ContainerVT = getContainerForFixedLengthVector(VT);
8410+
Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
8411+
MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
8412+
A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
8413+
B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
8414+
}
8415+
83948416
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
83958417
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
8396-
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
8397-
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
8418+
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
8419+
SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
8420+
if (VT.isFixedLengthVector())
8421+
Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
8422+
return Res;
83988423
}
83998424

84008425
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,

0 commit comments

Comments
 (0)