Skip to content

Commit 30ab2f1

Browse files
committed
[RISCV] Fold add_vl into accumulator operand of vqdot*
If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot instead. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for multiply add variants.
1 parent 73165de commit 30ab2f1

File tree

2 files changed

+82
-20
lines changed

2 files changed

+82
-20
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
1845918459
return DAG.getNode(Opc, DL, VT, Ops);
1846018460
}
1846118461

18462-
static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
18463-
ISD::MemIndexType &IndexType,
18464-
RISCVTargetLowering::DAGCombinerInfo &DCI) {
18462+
static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
18463+
const RISCVSubtarget &Subtarget) {
18464+
18465+
assert(N->getOpcode() == RISCVISD::ADD_VL);
18466+
18467+
if (!N->getValueType(0).isVector())
18468+
return SDValue();
18469+
18470+
SDValue Addend = N->getOperand(0);
18471+
SDValue DotOp = N->getOperand(1);
18472+
18473+
SDValue AddPassthruOp = N->getOperand(2);
18474+
if (!AddPassthruOp.isUndef())
18475+
return SDValue();
18476+
18477+
auto IsVqdotqOpc = [](unsigned Opc) {
18478+
switch (Opc) {
18479+
case RISCVISD::VQDOT_VL:
18480+
case RISCVISD::VQDOTU_VL:
18481+
case RISCVISD::VQDOTSU_VL:
18482+
return true;
18483+
default:
18484+
return false;
18485+
}
18486+
};
18487+
18488+
if (!IsVqdotqOpc(DotOp.getOpcode()))
18489+
std::swap(Addend, DotOp);
18490+
18491+
if (!IsVqdotqOpc(DotOp.getOpcode()))
18492+
return SDValue();
18493+
18494+
SDValue AddMask = N->getOperand(3);
18495+
SDValue AddVL = N->getOperand(4);
18496+
18497+
SDValue MulVL = DotOp.getOperand(4);
18498+
if (AddVL != MulVL)
18499+
return SDValue();
18500+
18501+
if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
18502+
AddMask.getOperand(0) != MulVL)
18503+
return SDValue();
18504+
18505+
SDValue AccumOp = DotOp.getOperand(2);
18506+
bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
18507+
// Peak through fixed to scalable
18508+
if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
18509+
AccumOp.getOperand(0).isUndef())
18510+
IsNullAdd =
18511+
ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
18512+
18513+
SDLoc DL(N);
18514+
EVT VT = N->getValueType(0);
18515+
// The manual constant folding is required, this case is not constant folded
18516+
// or combined.
18517+
if (!IsNullAdd)
18518+
Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
18519+
DAG.getUNDEF(VT), AddMask, AddVL);
18520+
18521+
SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
18522+
DotOp.getOperand(3), DotOp->getOperand(4)};
18523+
return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops);
18524+
}
18525+
18526+
static bool
18527+
legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
18528+
ISD::MemIndexType &IndexType,
18529+
RISCVTargetLowering::DAGCombinerInfo &DCI) {
1846518530
if (!DCI.isBeforeLegalize())
1846618531
return false;
1846718532

@@ -19582,6 +19647,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1958219647
case RISCVISD::ADD_VL:
1958319648
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
1958419649
return V;
19650+
if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
19651+
return V;
1958519652
return combineToVWMACC(N, DAG, Subtarget);
1958619653
case RISCVISD::VWADD_W_VL:
1958719654
case RISCVISD::VWADDU_W_VL:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
314314
; DOT-LABEL: vqdot_vv_accum:
315315
; DOT: # %bb.0: # %entry
316316
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
317-
; DOT-NEXT: vmv.v.i v10, 0
318-
; DOT-NEXT: vqdot.vv v10, v8, v9
319-
; DOT-NEXT: vadd.vv v8, v10, v12
317+
; DOT-NEXT: vmv1r.v v16, v12
318+
; DOT-NEXT: vqdot.vv v16, v8, v9
320319
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
321-
; DOT-NEXT: vmv.v.v v12, v8
320+
; DOT-NEXT: vmv.v.v v12, v16
322321
; DOT-NEXT: vmv.s.x v8, zero
323322
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
324323
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
349348
; DOT-LABEL: vqdotu_vv_accum:
350349
; DOT: # %bb.0: # %entry
351350
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
352-
; DOT-NEXT: vmv.v.i v10, 0
353-
; DOT-NEXT: vqdotu.vv v10, v8, v9
354-
; DOT-NEXT: vadd.vv v8, v10, v12
351+
; DOT-NEXT: vmv1r.v v16, v12
352+
; DOT-NEXT: vqdotu.vv v16, v8, v9
355353
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
356-
; DOT-NEXT: vmv.v.v v12, v8
354+
; DOT-NEXT: vmv.v.v v12, v16
357355
; DOT-NEXT: vmv.s.x v8, zero
358356
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
359357
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
384382
; DOT-LABEL: vqdotsu_vv_accum:
385383
; DOT: # %bb.0: # %entry
386384
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
387-
; DOT-NEXT: vmv.v.i v10, 0
388-
; DOT-NEXT: vqdotsu.vv v10, v8, v9
389-
; DOT-NEXT: vadd.vv v8, v10, v12
385+
; DOT-NEXT: vmv1r.v v16, v12
386+
; DOT-NEXT: vqdotsu.vv v16, v8, v9
390387
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
391-
; DOT-NEXT: vmv.v.v v12, v8
388+
; DOT-NEXT: vmv.v.v v12, v16
392389
; DOT-NEXT: vmv.s.x v8, zero
393390
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
394391
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %
516513
; DOT: # %bb.0: # %entry
517514
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
518515
; DOT-NEXT: vmv.v.i v12, 0
519-
; DOT-NEXT: vmv.v.i v13, 0
520516
; DOT-NEXT: vqdot.vv v12, v8, v9
521-
; DOT-NEXT: vqdot.vv v13, v10, v11
522-
; DOT-NEXT: vadd.vv v8, v12, v13
523-
; DOT-NEXT: vmv.s.x v9, zero
524-
; DOT-NEXT: vredsum.vs v8, v8, v9
517+
; DOT-NEXT: vqdot.vv v12, v10, v11
518+
; DOT-NEXT: vmv.s.x v8, zero
519+
; DOT-NEXT: vredsum.vs v8, v12, v8
525520
; DOT-NEXT: vmv.x.s a0, v8
526521
; DOT-NEXT: ret
527522
entry:

0 commit comments

Comments
 (0)