Skip to content

Commit 0759024

Browse files
committed
[RISCV] Lower PARTIAL_REDUCE_[S/U]MLA via zvqdotq
The semantics of the PARTIAL_REDUCE_SMLA with i32 result element, and i8 sources corresponds to vqdot. Analogously PARTIAL_REDUCE_UMLA corresponds to vqdotu. There is currently no vqdotsu equivalent. This patch is a starting place. We can extend this quite a bit more, and I plan to take a look at the fixed vector lowering, the TTI hook to drive loop vectorizer, and to try to integrate the reduction based lowering I'd added for zvqdotq into this flow.
1 parent e4e7a7e commit 0759024

File tree

3 files changed

+355
-209
lines changed

3 files changed

+355
-209
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,6 +1571,14 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15711571
setIndexedStoreAction(ISD::POST_INC, MVT::i32, Legal);
15721572
}
15731573

1574+
if (Subtarget.hasStdExtZvqdotq()) {
1575+
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
1576+
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
1577+
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
1578+
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
1579+
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
1580+
}
1581+
15741582
// Function alignments.
15751583
const Align FunctionAlignment(Subtarget.hasStdExtCOrZca() ? 2 : 4);
15761584
setMinFunctionAlignment(FunctionAlignment);
@@ -8229,6 +8237,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
82298237
return lowerINIT_TRAMPOLINE(Op, DAG);
82308238
case ISD::ADJUST_TRAMPOLINE:
82318239
return lowerADJUST_TRAMPOLINE(Op, DAG);
8240+
case ISD::PARTIAL_REDUCE_UMLA:
8241+
case ISD::PARTIAL_REDUCE_SMLA:
8242+
return lowerPARTIAL_REDUCE_MLA(Op, DAG);
82328243
}
82338244
}
82348245

@@ -8364,6 +8375,27 @@ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
83648375
return Op.getOperand(0);
83658376
}
83668377

8378+
SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
8379+
SelectionDAG &DAG) const {
8380+
// Currently, only the vqdot and vqdotu case (from zvqdotq) hould be legal.
8381+
// TODO: There are many other sub-cases we could potentially lower, are
8382+
// any of them worthwhile? Ex: via vredsum, vwredsum, vwwmaccu, etc..
8383+
// TODO: PARTIAL_REDUCE_*MLA can't represent a vqdotsu currently.
8384+
SDLoc DL(Op);
8385+
MVT VT = Op.getSimpleValueType();
8386+
SDValue Accum = Op.getOperand(0);
8387+
assert(Accum.getSimpleValueType() == VT &&
8388+
VT.getVectorElementType() == MVT::i32);
8389+
SDValue A = Op.getOperand(1);
8390+
SDValue B = Op.getOperand(2);
8391+
assert(A.getSimpleValueType() == B.getSimpleValueType() &&
8392+
A.getSimpleValueType().getVectorElementType() == MVT::i8);
8393+
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
8394+
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
8395+
auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
8396+
return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
8397+
}
8398+
83678399
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
83688400
SelectionDAG &DAG, unsigned Flags) {
83698401
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ class RISCVTargetLowering : public TargetLowering {
552552

553553
SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
554554
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
555+
SDValue lowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
555556

556557
bool isEligibleForTailCallOptimization(
557558
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,

0 commit comments

Comments
 (0)