From 30ab2f15310675baf7bbc82195837ee7fb50d9e8 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Wed, 7 May 2025 08:45:26 -0700 Subject: [PATCH 1/2] [RISCV] Fold add_vl into accumulator operand of vqdot* If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot instead. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for multiply add variants. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 73 ++++++++++++++++++- .../RISCV/rvv/fixed-vectors-zvqdotq.ll | 29 +++----- 2 files changed, 82 insertions(+), 20 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c53550ea3b23b..93aabdc004b42 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18459,9 +18459,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, return DAG.getNode(Opc, DL, VT, Ops); } -static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index, - ISD::MemIndexType &IndexType, - RISCVTargetLowering::DAGCombinerInfo &DCI) { +static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + + assert(N->getOpcode() == RISCVISD::ADD_VL); + + if (!N->getValueType(0).isVector()) + return SDValue(); + + SDValue Addend = N->getOperand(0); + SDValue DotOp = N->getOperand(1); + + SDValue AddPassthruOp = N->getOperand(2); + if (!AddPassthruOp.isUndef()) + return SDValue(); + + auto IsVqdotqOpc = [](unsigned Opc) { + switch (Opc) { + case RISCVISD::VQDOT_VL: + case RISCVISD::VQDOTU_VL: + case RISCVISD::VQDOTSU_VL: + return true; + default: + return false; + } + }; + + if (!IsVqdotqOpc(DotOp.getOpcode())) + std::swap(Addend, DotOp); + + if (!IsVqdotqOpc(DotOp.getOpcode())) + return SDValue(); + + SDValue AddMask = N->getOperand(3); + SDValue AddVL = N->getOperand(4); + + SDValue MulVL = DotOp.getOperand(4); + 
if (AddVL != MulVL) + return SDValue(); + + if (AddMask.getOpcode() != RISCVISD::VMSET_VL || + AddMask.getOperand(0) != MulVL) + return SDValue(); + + SDValue AccumOp = DotOp.getOperand(2); + bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode()); + // Peak through fixed to scalable + if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR && + AccumOp.getOperand(0).isUndef()) + IsNullAdd = + ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode()); + + SDLoc DL(N); + EVT VT = N->getValueType(0); + // The manual constant folding is required, this case is not constant folded + // or combined. + if (!IsNullAdd) + Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend, + DAG.getUNDEF(VT), AddMask, AddVL); + + SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend, + DotOp.getOperand(3), DotOp->getOperand(4)}; + return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops); +} + +static bool +legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index, + ISD::MemIndexType &IndexType, + RISCVTargetLowering::DAGCombinerInfo &DCI) { if (!DCI.isBeforeLegalize()) return false; @@ -19582,6 +19647,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, case RISCVISD::ADD_VL: if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget)) return V; + if (SDValue V = combineVqdotAccum(N, DAG, Subtarget)) + return V; return combineToVWMACC(N, DAG, Subtarget); case RISCVISD::VWADD_W_VL: case RISCVISD::VWADDU_W_VL: diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll index e5546ad404c1b..ff61ef82176e6 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll @@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) { ; DOT-LABEL: vqdot_vv_accum: ; DOT: # %bb.0: # %entry ; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma -; DOT-NEXT: vmv.v.i v10, 0 -; DOT-NEXT: vqdot.vv v10, v8, v9 
-; DOT-NEXT: vadd.vv v8, v10, v12 +; DOT-NEXT: vmv1r.v v16, v12 +; DOT-NEXT: vqdot.vv v16, v8, v9 ; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma -; DOT-NEXT: vmv.v.v v12, v8 +; DOT-NEXT: vmv.v.v v12, v16 ; DOT-NEXT: vmv.s.x v8, zero ; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma ; DOT-NEXT: vredsum.vs v8, v12, v8 @@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) { ; DOT-LABEL: vqdotu_vv_accum: ; DOT: # %bb.0: # %entry ; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma -; DOT-NEXT: vmv.v.i v10, 0 -; DOT-NEXT: vqdotu.vv v10, v8, v9 -; DOT-NEXT: vadd.vv v8, v10, v12 +; DOT-NEXT: vmv1r.v v16, v12 +; DOT-NEXT: vqdotu.vv v16, v8, v9 ; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma -; DOT-NEXT: vmv.v.v v12, v8 +; DOT-NEXT: vmv.v.v v12, v16 ; DOT-NEXT: vmv.s.x v8, zero ; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma ; DOT-NEXT: vredsum.vs v8, v12, v8 @@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) { ; DOT-LABEL: vqdotsu_vv_accum: ; DOT: # %bb.0: # %entry ; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma -; DOT-NEXT: vmv.v.i v10, 0 -; DOT-NEXT: vqdotsu.vv v10, v8, v9 -; DOT-NEXT: vadd.vv v8, v10, v12 +; DOT-NEXT: vmv1r.v v16, v12 +; DOT-NEXT: vqdotsu.vv v16, v8, v9 ; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma -; DOT-NEXT: vmv.v.v v12, v8 +; DOT-NEXT: vmv.v.v v12, v16 ; DOT-NEXT: vmv.s.x v8, zero ; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma ; DOT-NEXT: vredsum.vs v8, v12, v8 @@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> % ; DOT: # %bb.0: # %entry ; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma ; DOT-NEXT: vmv.v.i v12, 0 -; DOT-NEXT: vmv.v.i v13, 0 ; DOT-NEXT: vqdot.vv v12, v8, v9 -; DOT-NEXT: vqdot.vv v13, v10, v11 -; DOT-NEXT: vadd.vv v8, v12, v13 -; DOT-NEXT: vmv.s.x v9, zero -; DOT-NEXT: vredsum.vs v8, v8, v9 +; DOT-NEXT: vqdot.vv v12, v10, v11 +; DOT-NEXT: vmv.s.x v8, zero +; DOT-NEXT: vredsum.vs v8, v12, v8 ; DOT-NEXT: vmv.x.s a0, v8 ; 
DOT-NEXT: ret entry: From 6607abe26d3b1769cb6e48b8b71318fb9895324d Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Mon, 12 May 2025 14:05:52 -0700 Subject: [PATCH 2/2] Address review comments --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 93aabdc004b42..bca3926f5ecc6 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18395,7 +18395,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG, static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD); if (N->getValueType(0).isFixedLengthVector()) @@ -18504,7 +18503,7 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG, SDValue AccumOp = DotOp.getOperand(2); bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode()); - // Peak through fixed to scalable + // Peek through fixed to scalable if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR && AccumOp.getOperand(0).isUndef()) IsNullAdd =