[RISCV] Custom lower fixed length partial.reduce to zvqdotq #141180
Conversation
This is a follow-on to 9b4de7, handling the fixed vector cases. In retrospect, this is simple enough that it probably should have just been part of the original commit, but oh well.
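For reference, a minimal sketch of the kind of fixed-length IR that now reaches the custom lowering (the function name @fixed_partial_reduce is made up; the intrinsic shape mirrors the tests added in this diff). With zvqdotq available, the PARTIAL_REDUCE_SMLA node built for this call selects to vqdot.vv instead of being expanded:

; Signed i8 -> i32 dot-product style partial reduction over fixed vectors.
define <4 x i32> @fixed_partial_reduce(<16 x i8> %a, <16 x i8> %b, <4 x i32> %acc) {
entry:
  %a.sext = sext <16 x i8> %a to <16 x i32>
  %b.sext = sext <16 x i8> %b to <16 x i32>
  %mul = mul <16 x i32> %a.sext, %b.sext
  %res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %mul)
  ret <4 x i32> %res
}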
@llvm/pr-subscribers-backend-risc-v Author: Philip Reames (preames) Changes: This is a follow-on to 9b4de7, handling the fixed vector cases. In retrospect, this is simple enough that it probably should have just been part of the original commit, but oh well. Full diff: https://github.com/llvm/llvm-project/pull/141180.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 476596e4e0104..05817ded78438 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1578,6 +1578,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+
+ if (Subtarget.useRVVForFixedLengthVectors()) {
+ for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
+ if (VT.getVectorElementType() != MVT::i32 ||
+ !useRVVForFixedLengthVectorVT(VT))
+ continue;
+ ElementCount EC = VT.getVectorElementCount();
+ MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
+ setPartialReduceMLAAction(VT, ArgVT, Custom);
+ }
+ }
}
// Function alignments.
@@ -8389,12 +8400,26 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
VT.getVectorElementType() == MVT::i32);
SDValue A = Op.getOperand(1);
SDValue B = Op.getOperand(2);
- assert(A.getSimpleValueType() == B.getSimpleValueType() &&
- A.getSimpleValueType().getVectorElementType() == MVT::i8);
+ MVT ArgVT = A.getSimpleValueType();
+ assert(ArgVT == B.getSimpleValueType() &&
+ ArgVT.getVectorElementType() == MVT::i8);
+
+ MVT ContainerVT = VT;
+ if (VT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(VT);
+ Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
+ MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
+ A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
+ B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
+ }
+
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
- auto [Mask, VL] = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
- return DAG.getNode(Opc, DL, VT, {A, B, Accum, Mask, VL});
+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+ SDValue Res = DAG.getNode(Opc, DL, ContainerVT, {A, B, Accum, Mask, VL});
+ if (VT.isFixedLengthVector())
+ Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
+ return Res;
}
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 7faa810f83236..75be742e07522 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -531,26 +531,279 @@ entry:
ret i32 %sum
}
+define <1 x i32> @vqdot_vv_partial_reduce_v1i32_v4i8(<4 x i8> %a, <4 x i8> %b) {
+; NODOT-LABEL: vqdot_vv_partial_reduce_v1i32_v4i8:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vsext.vf2 v8, v9
+; NODOT-NEXT: vwmul.vv v9, v10, v8
+; NODOT-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; NODOT-NEXT: vslidedown.vi v8, v9, 3
+; NODOT-NEXT: vslidedown.vi v10, v9, 2
+; NODOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v8, v9
+; NODOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
+; NODOT-NEXT: vslidedown.vi v9, v9, 1
+; NODOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; NODOT-NEXT: vadd.vv v9, v9, v10
+; NODOT-NEXT: vadd.vv v8, v9, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_partial_reduce_v1i32_v4i8:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
+; DOT-NEXT: vmv.s.x v10, zero
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv1r.v v8, v10
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <4 x i8> %a to <4 x i32>
+ %b.sext = sext <4 x i8> %b to <4 x i32>
+ %mul = mul <4 x i32> %a.sext, %b.sext
+ %res = call <1 x i32> @llvm.experimental.vector.partial.reduce.add(<1 x i32> zeroinitializer, <4 x i32> %mul)
+ ret <1 x i32> %res
+}
+
+define <2 x i32> @vqdot_vv_partial_reduce_v2i32_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; NODOT-LABEL: vqdot_vv_partial_reduce_v2i32_v8i8:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; NODOT-NEXT: vsext.vf2 v10, v8
+; NODOT-NEXT: vsext.vf2 v11, v9
+; NODOT-NEXT: vwmul.vv v8, v10, v11
+; NODOT-NEXT: vsetivli zero, 2, e32, m2, ta, ma
+; NODOT-NEXT: vslidedown.vi v10, v8, 6
+; NODOT-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NODOT-NEXT: vadd.vv v12, v10, v8
+; NODOT-NEXT: vsetivli zero, 2, e32, m2, ta, ma
+; NODOT-NEXT: vslidedown.vi v10, v8, 4
+; NODOT-NEXT: vsetivli zero, 2, e32, m1, ta, ma
+; NODOT-NEXT: vslidedown.vi v8, v8, 2
+; NODOT-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NODOT-NEXT: vadd.vv v8, v8, v10
+; NODOT-NEXT: vadd.vv v8, v8, v12
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_partial_reduce_v2i32_v8i8:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv1r.v v8, v10
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <8 x i8> %a to <8 x i32>
+ %b.sext = sext <8 x i8> %b to <8 x i32>
+ %mul = mul <8 x i32> %a.sext, %b.sext
+ %res = call <2 x i32> @llvm.experimental.vector.partial.reduce.add(<2 x i32> zeroinitializer, <8 x i32> %mul)
+ ret <2 x i32> %res
+}
-define <4 x i32> @vqdot_vv_partial_reduce(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv_partial_reduce:
+define <2 x i32> @vqdot_vv_partial_reduce_v2i32_v64i8(<64 x i8> %a, <64 x i8> %b) {
+; CHECK-LABEL: vqdot_vv_partial_reduce_v2i32_v64i8:
; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vwmul.vv v8, v12, v14
-; CHECK-NEXT: vsetivli zero, 4, e32, m4, ta, ma
+; CHECK-NEXT: addi sp, sp, -16
+; CHECK-NEXT: .cfi_def_cfa_offset 16
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: slli a1, a0, 2
+; CHECK-NEXT: add a0, a1, a0
+; CHECK-NEXT: sub sp, sp, a0
+; CHECK-NEXT: .cfi_escape 0x0f, 0x0d, 0x72, 0x00, 0x11, 0x10, 0x22, 0x11, 0x05, 0x92, 0xa2, 0x38, 0x00, 0x1e, 0x22 # sp + 16 + 5 * vlenb
+; CHECK-NEXT: li a0, 32
+; CHECK-NEXT: vsetvli zero, a0, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v24, v8
+; CHECK-NEXT: vsext.vf2 v28, v12
+; CHECK-NEXT: vwmul.vv v16, v24, v28
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v16, 28
+; CHECK-NEXT: vslidedown.vi v0, v16, 26
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v24, v0, v24
+; CHECK-NEXT: csrr a1, vlenb
+; CHECK-NEXT: slli a1, a1, 2
+; CHECK-NEXT: add a1, sp, a1
+; CHECK-NEXT: addi a1, a1, 16
+; CHECK-NEXT: vs1r.v v24, (a1) # vscale x 8-byte Folded Spill
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v0, v16, 24
+; CHECK-NEXT: vslidedown.vi v24, v16, 22
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v24, v24, v0
+; CHECK-NEXT: csrr a1, vlenb
+; CHECK-NEXT: slli a2, a1, 1
+; CHECK-NEXT: add a1, a2, a1
+; CHECK-NEXT: add a1, sp, a1
+; CHECK-NEXT: addi a1, a1, 16
+; CHECK-NEXT: vs1r.v v24, (a1) # vscale x 8-byte Folded Spill
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v0, v16, 20
+; CHECK-NEXT: vslidedown.vi v24, v16, 18
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v24, v24, v0
+; CHECK-NEXT: csrr a1, vlenb
+; CHECK-NEXT: slli a1, a1, 1
+; CHECK-NEXT: add a1, sp, a1
+; CHECK-NEXT: addi a1, a1, 16
+; CHECK-NEXT: vs1r.v v24, (a1) # vscale x 8-byte Folded Spill
+; CHECK-NEXT: vsetvli zero, a0, e8, m4, ta, ma
+; CHECK-NEXT: vslidedown.vx v8, v8, a0
+; CHECK-NEXT: vslidedown.vx v12, v12, a0
+; CHECK-NEXT: vsetvli zero, a0, e16, m4, ta, ma
+; CHECK-NEXT: vsext.vf2 v24, v8
+; CHECK-NEXT: vsext.vf2 v28, v12
+; CHECK-NEXT: vwmul.vv v8, v24, v28
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v0, v8, 28
+; CHECK-NEXT: vslidedown.vi v24, v8, 26
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v24, v24, v0
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: add a0, sp, a0
+; CHECK-NEXT: addi a0, a0, 16
+; CHECK-NEXT: vs1r.v v24, (a0) # vscale x 8-byte Folded Spill
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v8, 24
+; CHECK-NEXT: vslidedown.vi v0, v8, 22
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v24, v0, v24
+; CHECK-NEXT: addi a0, sp, 16
+; CHECK-NEXT: vs1r.v v24, (a0) # vscale x 8-byte Folded Spill
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v8, 20
+; CHECK-NEXT: vslidedown.vi v0, v8, 18
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v7, v0, v24
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v16, 16
+; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
+; CHECK-NEXT: vslidedown.vi v28, v16, 14
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v6, v28, v24
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v8, 16
+; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
+; CHECK-NEXT: vslidedown.vi v28, v8, 14
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v5, v28, v24
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v16, 30
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v4, v24, v16
+; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
+; CHECK-NEXT: vslidedown.vi v20, v16, 12
+; CHECK-NEXT: vslidedown.vi v24, v16, 10
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v3, v24, v20
+; CHECK-NEXT: vsetivli zero, 2, e32, m8, ta, ma
+; CHECK-NEXT: vslidedown.vi v24, v8, 30
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v24, v24, v8
+; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
; CHECK-NEXT: vslidedown.vi v12, v8, 12
-; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v16, v12, v8
-; CHECK-NEXT: vsetivli zero, 4, e32, m4, ta, ma
-; CHECK-NEXT: vslidedown.vi v12, v8, 8
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT: vslidedown.vi v8, v8, 4
-; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v8, v8, v12
-; CHECK-NEXT: vadd.vv v8, v8, v16
+; CHECK-NEXT: vslidedown.vi v20, v8, 10
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v25, v20, v12
+; CHECK-NEXT: vsetivli zero, 2, e32, m2, ta, ma
+; CHECK-NEXT: vslidedown.vi v20, v16, 6
+; CHECK-NEXT: vslidedown.vi v22, v16, 4
+; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
+; CHECK-NEXT: vslidedown.vi v21, v16, 2
+; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
+; CHECK-NEXT: vslidedown.vi v12, v16, 8
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v16, v20, v12
+; CHECK-NEXT: vsetivli zero, 2, e32, m2, ta, ma
+; CHECK-NEXT: vslidedown.vi v12, v8, 6
+; CHECK-NEXT: vslidedown.vi v14, v8, 4
+; CHECK-NEXT: vsetivli zero, 2, e32, m1, ta, ma
+; CHECK-NEXT: vslidedown.vi v13, v8, 2
+; CHECK-NEXT: vsetivli zero, 2, e32, m4, ta, ma
+; CHECK-NEXT: vslidedown.vi v8, v8, 8
+; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT: vadd.vv v8, v12, v8
+; CHECK-NEXT: vadd.vv v9, v21, v22
+; CHECK-NEXT: vadd.vv v10, v13, v14
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: slli a0, a0, 2
+; CHECK-NEXT: add a0, sp, a0
+; CHECK-NEXT: addi a0, a0, 16
+; CHECK-NEXT: vl1r.v v11, (a0) # vscale x 8-byte Folded Reload
+; CHECK-NEXT: vadd.vv v11, v11, v4
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: slli a1, a0, 1
+; CHECK-NEXT: add a0, a1, a0
+; CHECK-NEXT: add a0, sp, a0
+; CHECK-NEXT: addi a0, a0, 16
+; CHECK-NEXT: vl1r.v v12, (a0) # vscale x 8-byte Folded Reload
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: slli a0, a0, 1
+; CHECK-NEXT: add a0, sp, a0
+; CHECK-NEXT: addi a0, a0, 16
+; CHECK-NEXT: vl1r.v v13, (a0) # vscale x 8-byte Folded Reload
+; CHECK-NEXT: vadd.vv v12, v13, v12
+; CHECK-NEXT: vadd.vv v13, v3, v6
+; CHECK-NEXT: vadd.vv v9, v9, v16
+; CHECK-NEXT: vadd.vv v11, v12, v11
+; CHECK-NEXT: vadd.vv v9, v9, v13
+; CHECK-NEXT: addi a0, sp, 16
+; CHECK-NEXT: vl1r.v v12, (a0) # vscale x 8-byte Folded Reload
+; CHECK-NEXT: vadd.vv v12, v7, v12
+; CHECK-NEXT: vadd.vv v13, v25, v5
+; CHECK-NEXT: vadd.vv v8, v10, v8
+; CHECK-NEXT: vadd.vv v9, v9, v11
+; CHECK-NEXT: vadd.vv v9, v24, v9
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: add a0, sp, a0
+; CHECK-NEXT: addi a0, a0, 16
+; CHECK-NEXT: vl1r.v v10, (a0) # vscale x 8-byte Folded Reload
+; CHECK-NEXT: vadd.vv v9, v10, v9
+; CHECK-NEXT: vadd.vv v9, v12, v9
+; CHECK-NEXT: vadd.vv v8, v8, v13
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: csrr a0, vlenb
+; CHECK-NEXT: slli a1, a0, 2
+; CHECK-NEXT: add a0, a1, a0
+; CHECK-NEXT: add sp, sp, a0
+; CHECK-NEXT: .cfi_def_cfa sp, 16
+; CHECK-NEXT: addi sp, sp, 16
+; CHECK-NEXT: .cfi_def_cfa_offset 0
; CHECK-NEXT: ret
+entry:
+ %a.sext = sext <64 x i8> %a to <64 x i32>
+ %b.sext = sext <64 x i8> %b to <64 x i32>
+ %mul = mul <64 x i32> %a.sext, %b.sext
+ %res = call <2 x i32> @llvm.experimental.vector.partial.reduce.add(<2 x i32> zeroinitializer, <64 x i32> %mul)
+ ret <2 x i32> %res
+}
+
+define <4 x i32> @vqdot_vv_partial_reduce_v4i32_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; NODOT-LABEL: vqdot_vv_partial_reduce_v4i32_v16i8:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vwmul.vv v8, v12, v14
+; NODOT-NEXT: vsetivli zero, 4, e32, m4, ta, ma
+; NODOT-NEXT: vslidedown.vi v12, v8, 12
+; NODOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v16, v12, v8
+; NODOT-NEXT: vsetivli zero, 4, e32, m4, ta, ma
+; NODOT-NEXT: vslidedown.vi v12, v8, 8
+; NODOT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; NODOT-NEXT: vslidedown.vi v8, v8, 4
+; NODOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v8, v8, v12
+; NODOT-NEXT: vadd.vv v8, v8, v16
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_partial_reduce_v4i32_v16i8:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -559,27 +812,77 @@ entry:
ret <4 x i32> %res
}
-define <4 x i32> @vqdot_vv_partial_reduce2(<16 x i8> %a, <16 x i8> %b, <4 x i32> %accum) {
-; CHECK-LABEL: vqdot_vv_partial_reduce2:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v16, v8
-; CHECK-NEXT: vsext.vf2 v18, v9
-; CHECK-NEXT: vwmul.vv v12, v16, v18
-; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v16, v10, v12
-; CHECK-NEXT: vsetivli zero, 4, e32, m4, ta, ma
-; CHECK-NEXT: vslidedown.vi v8, v12, 12
-; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v16, v8, v16
-; CHECK-NEXT: vsetivli zero, 4, e32, m4, ta, ma
-; CHECK-NEXT: vslidedown.vi v8, v12, 8
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, ta, ma
-; CHECK-NEXT: vslidedown.vi v10, v12, 4
-; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v8, v10, v8
-; CHECK-NEXT: vadd.vv v8, v8, v16
-; CHECK-NEXT: ret
+define <16 x i32> @vqdot_vv_partial_reduce_v16i32_v64i8(<64 x i8> %a, <64 x i8> %b) {
+; NODOT-LABEL: vqdot_vv_partial_reduce_v16i32_v64i8:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: li a0, 32
+; NODOT-NEXT: vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT: vslidedown.vx v8, v8, a0
+; NODOT-NEXT: vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v20, v12
+; NODOT-NEXT: vsetvli zero, a0, e8, m4, ta, ma
+; NODOT-NEXT: vslidedown.vx v12, v12, a0
+; NODOT-NEXT: vsetvli zero, a0, e16, m4, ta, ma
+; NODOT-NEXT: vsext.vf2 v24, v8
+; NODOT-NEXT: vsext.vf2 v28, v12
+; NODOT-NEXT: vwmul.vv v8, v16, v20
+; NODOT-NEXT: vwmul.vv v16, v24, v28
+; NODOT-NEXT: vsetivli zero, 16, e32, m8, ta, ma
+; NODOT-NEXT: vslidedown.vi v24, v8, 16
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vadd.vv v8, v24, v8
+; NODOT-NEXT: vadd.vv v24, v8, v16
+; NODOT-NEXT: vsetivli zero, 16, e32, m8, ta, ma
+; NODOT-NEXT: vslidedown.vi v8, v16, 16
+; NODOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT: vadd.vv v8, v8, v24
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_partial_reduce_v16i32_v64i8:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
+; DOT-NEXT: vmv.v.i v16, 0
+; DOT-NEXT: vqdot.vv v16, v8, v12
+; DOT-NEXT: vmv.v.v v8, v16
+; DOT-NEXT: ret
+entry:
+ %a.sext = sext <64 x i8> %a to <64 x i32>
+ %b.sext = sext <64 x i8> %b to <64 x i32>
+ %mul = mul <64 x i32> %a.sext, %b.sext
+ %res = call <16 x i32> @llvm.experimental.vector.partial.reduce.add(<16 x i32> zeroinitializer, <64 x i32> %mul)
+ ret <16 x i32> %res
+}
+
+define <4 x i32> @vqdot_vv_partial_reduce_m1_accum(<16 x i8> %a, <16 x i8> %b, <4 x i32> %accum) {
+; NODOT-LABEL: vqdot_vv_partial_reduce_m1_accum:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v16, v8
+; NODOT-NEXT: vsext.vf2 v18, v9
+; NODOT-NEXT: vwmul.vv v12, v16, v18
+; NODOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v16, v10, v12
+; NODOT-NEXT: vsetivli zero, 4, e32, m4, ta, ma
+; NODOT-NEXT: vslidedown.vi v8, v12, 12
+; NODOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v16, v8, v16
+; NODOT-NEXT: vsetivli zero, 4, e32, m4, ta, ma
+; NODOT-NEXT: vslidedown.vi v8, v12, 8
+; NODOT-NEXT: vsetivli zero, 4, e32, m2, ta, ma
+; NODOT-NEXT: vslidedown.vi v10, v12, 4
+; NODOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v8, v10, v8
+; NODOT-NEXT: vadd.vv v8, v8, v16
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: vqdot_vv_partial_reduce_m1_accum:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT: vqdot.vv v10, v8, v9
+; DOT-NEXT: vmv.v.v v8, v10
+; DOT-NEXT: ret
entry:
%a.sext = sext <16 x i8> %a to <16 x i32>
%b.sext = sext <16 x i8> %b to <16 x i32>
Are there no tests for vqdotu?
There hadn't been for fixed vectors since they were covered by the scalable tests, but I added vqdotu and vqdotsu (not yet supported) cases.
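For completeness, a hedged sketch of the unsigned shape such a vqdotu test would cover (not the exact test added in the follow-up; the function name is made up). Using zext instead of sext yields PARTIAL_REDUCE_UMLA, which the lowering maps to vqdotu.vv via RISCVISD::VQDOTU_VL:

define <4 x i32> @fixed_partial_reduce_u(<16 x i8> %a, <16 x i8> %b) {
entry:
  %a.zext = zext <16 x i8> %a to <16 x i32>
  %b.zext = zext <16 x i8> %b to <16 x i32>
  %mul = mul <16 x i32> %a.zext, %b.zext
  %res = call <4 x i32> @llvm.experimental.vector.partial.reduce.add(<4 x i32> zeroinitializer, <16 x i32> %mul)
  ret <4 x i32> %res
}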