
Commit 39d4dfb

[RISCV] Incorporate scalar addends to extend vector multiply accumulate chains (#168660)
Previously, the following IR:

    %mul0 = mul nsw <8 x i32> %m00, %m01
    %mul1 = mul nsw <8 x i32> %m10, %m11
    %add0 = add <8 x i32> %mul0, splat (i32 32)
    %add1 = add <8 x i32> %add0, %mul1

lowered to:

    vsetivli zero, 8, e32, m2, ta, ma
    vmul.vv v8, v8, v9
    vmacc.vv v8, v11, v10
    li a0, 32
    vadd.vx v8, v8, a0

After this patch, it now lowers to:

    li a0, 32
    vsetivli zero, 8, e32, m2, ta, ma
    vmv.v.x v12, a0
    vmadd.vv v8, v9, v12
    vmacc.vv v8, v11, v10

Modeled on 0cc981e from the AArch64 backend.

C code for the example case (`clang -O3 -S -mcpu=sifive-x280`):

```c
void madd_fail(int a, int b, int * restrict src, int * restrict dst, int loop_bound) {
  for (int i = 0; i < loop_bound; i += 2) {
    dst[i] = src[i] * a + src[i + 1] * b + 32;
  }
}
```
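For quick reproduction outside the test suite, the vector IR from the description can be fed straight to `llc`; the sketch below is not part of the commit, the function name `@madd_chain` and the file name are illustrative, and it assumes an `llc` build that includes this patch.

```llvm
; Illustrative reproducer (not from the commit); with a patched llc it is
; expected to produce the vmadd.vv/vmacc.vv sequence shown above:
;   llc -mtriple=riscv64 -mattr=+m,+v reproducer.ll -o -
define <8 x i32> @madd_chain(<8 x i32> %m00, <8 x i32> %m01,
                             <8 x i32> %m10, <8 x i32> %m11) {
  %mul0 = mul nsw <8 x i32> %m00, %m01
  %mul1 = mul nsw <8 x i32> %m10, %m11
  ; The splat addend stays attached to the first add, so it can fold into
  ; vmadd.vv while the second multiply folds into vmacc.vv.
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
  ret <8 x i32> %add1
}
```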
1 parent f8a8039 commit 39d4dfb

3 files changed: +162 -0 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 14 additions & 0 deletions
@@ -25722,3 +25722,17 @@ bool RISCVTargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
 
   return VT.getSizeInBits() <= Subtarget.getXLen();
 }
+
+bool RISCVTargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                              SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  // Avoid reassociating expressions that can be lowered to vector
+  // multiply accumulate (i.e. add (mul x, y), z)
+  if (N0.getOpcode() == ISD::ADD && N1.getOpcode() == ISD::MUL &&
+      (N0.getValueType().isVector() && Subtarget.hasVInstructions()))
+    return false;
+
+  return true;
+}

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 5 additions & 0 deletions
@@ -470,6 +470,11 @@ class RISCVTargetLowering : public TargetLowering {
 
   bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Match a mask which "spreads" the leading elements of a vector evenly
   /// across the result. Factor is the spread amount, and Index is the
   /// offset applied.
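To make the doc comment concrete, here is a small scalar instance of the pattern it describes; the function `@reassoc_ok` is purely illustrative and not part of the patch. With N0 = `(add %x, 16)` and N1 = `%y`, and assuming `%t` has a single use, the RISC-V override still returns true (no vector multiply-accumulate is involved), so the combiner may apply the rewrite.

```llvm
; Hypothetical example of the reassociation the hook controls:
;   (op (op x, c1), y) -> (op (op x, y), c1)
define i32 @reassoc_ok(i32 %x, i32 %y) {
  %t = add i32 %x, 16   ; N0 = (add %x, 16), the inner (op x, c1)
  %r = add i32 %t, %y   ; N1 = %y; may be rewritten to add (add %x, %y), 16
  ret i32 %r
}
```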
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=riscv64 -mattr=+m,+v < %s | FileCheck %s

define i32 @madd_scalar(i32 %m00, i32 %m01, i32 %m10, i32 %m11) nounwind {
; CHECK-LABEL: madd_scalar:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    mul a0, a0, a1
; CHECK-NEXT:    mul a1, a2, a3
; CHECK-NEXT:    add a0, a0, a1
; CHECK-NEXT:    addiw a0, a0, 32
; CHECK-NEXT:    ret
entry:
  %mul0 = mul i32 %m00, %m01
  %mul1 = mul i32 %m10, %m11
  %add0 = add i32 %mul0, 32
  %add1 = add i32 %add0, %mul1
  ret i32 %add1
}

define <8 x i32> @vmadd_non_constant(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11, <8 x i32> %addend) {
; CHECK-LABEL: vmadd_non_constant:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT:    vmadd.vv v8, v10, v16
; CHECK-NEXT:    vmacc.vv v8, v14, v12
; CHECK-NEXT:    ret
entry:
  %mul0 = mul <8 x i32> %m00, %m01
  %mul1 = mul <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, %addend
  %add1 = add <8 x i32> %add0, %mul1
  ret <8 x i32> %add1
}

define <vscale x 1 x i32> @vmadd_vscale_no_chain(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01) {
; CHECK-LABEL: vmadd_vscale_no_chain:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT:    vmv.v.x v10, a0
; CHECK-NEXT:    vmadd.vv v8, v9, v10
; CHECK-NEXT:    ret
entry:
  %mul = mul <vscale x 1 x i32> %m00, %m01
  %add = add <vscale x 1 x i32> %mul, splat (i32 32)
  ret <vscale x 1 x i32> %add
}

define <8 x i32> @vmadd_fixed_no_chain(<8 x i32> %m00, <8 x i32> %m01) {
; CHECK-LABEL: vmadd_fixed_no_chain:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT:    vmv.v.x v12, a0
; CHECK-NEXT:    vmadd.vv v8, v10, v12
; CHECK-NEXT:    ret
entry:
  %mul = mul <8 x i32> %m00, %m01
  %add = add <8 x i32> %mul, splat (i32 32)
  ret <8 x i32> %add
}

define <vscale x 1 x i32> @vmadd_vscale(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11) {
; CHECK-LABEL: vmadd_vscale:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT:    vmv.v.x v12, a0
; CHECK-NEXT:    vmadd.vv v8, v9, v12
; CHECK-NEXT:    vmacc.vv v8, v11, v10
; CHECK-NEXT:    ret
entry:
  %mul0 = mul <vscale x 1 x i32> %m00, %m01
  %mul1 = mul <vscale x 1 x i32> %m10, %m11
  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
  %add1 = add <vscale x 1 x i32> %add0, %mul1
  ret <vscale x 1 x i32> %add1
}

define <8 x i32> @vmadd_fixed(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11) {
; CHECK-LABEL: vmadd_fixed:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT:    vmv.v.x v16, a0
; CHECK-NEXT:    vmadd.vv v8, v10, v16
; CHECK-NEXT:    vmacc.vv v8, v14, v12
; CHECK-NEXT:    ret
entry:
  %mul0 = mul <8 x i32> %m00, %m01
  %mul1 = mul <8 x i32> %m10, %m11
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
  ret <8 x i32> %add1
}

define <vscale x 1 x i32> @vmadd_vscale_long(<vscale x 1 x i32> %m00, <vscale x 1 x i32> %m01, <vscale x 1 x i32> %m10, <vscale x 1 x i32> %m11,
; CHECK-LABEL: vmadd_vscale_long:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
; CHECK-NEXT:    vmv.v.x v16, a0
; CHECK-NEXT:    vmadd.vv v8, v9, v16
; CHECK-NEXT:    vmacc.vv v8, v11, v10
; CHECK-NEXT:    vmacc.vv v8, v13, v12
; CHECK-NEXT:    vmacc.vv v8, v15, v14
; CHECK-NEXT:    ret
  <vscale x 1 x i32> %m20, <vscale x 1 x i32> %m21, <vscale x 1 x i32> %m30, <vscale x 1 x i32> %m31) {
entry:
  %mul0 = mul <vscale x 1 x i32> %m00, %m01
  %mul1 = mul <vscale x 1 x i32> %m10, %m11
  %mul2 = mul <vscale x 1 x i32> %m20, %m21
  %mul3 = mul <vscale x 1 x i32> %m30, %m31
  %add0 = add <vscale x 1 x i32> %mul0, splat (i32 32)
  %add1 = add <vscale x 1 x i32> %add0, %mul1
  %add2 = add <vscale x 1 x i32> %add1, %mul2
  %add3 = add <vscale x 1 x i32> %add2, %mul3
  ret <vscale x 1 x i32> %add3
}

define <8 x i32> @vmadd_fixed_long(<8 x i32> %m00, <8 x i32> %m01, <8 x i32> %m10, <8 x i32> %m11,
; CHECK-LABEL: vmadd_fixed_long:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    li a0, 32
; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
; CHECK-NEXT:    vmv.v.x v24, a0
; CHECK-NEXT:    vmadd.vv v8, v10, v24
; CHECK-NEXT:    vmacc.vv v8, v14, v12
; CHECK-NEXT:    vmacc.vv v8, v18, v16
; CHECK-NEXT:    vmacc.vv v8, v22, v20
; CHECK-NEXT:    ret
  <8 x i32> %m20, <8 x i32> %m21, <8 x i32> %m30, <8 x i32> %m31) {
entry:
  %mul0 = mul <8 x i32> %m00, %m01
  %mul1 = mul <8 x i32> %m10, %m11
  %mul2 = mul <8 x i32> %m20, %m21
  %mul3 = mul <8 x i32> %m30, %m31
  %add0 = add <8 x i32> %mul0, splat (i32 32)
  %add1 = add <8 x i32> %add0, %mul1
  %add2 = add <8 x i32> %add1, %mul2
  %add3 = add <8 x i32> %add2, %mul3
  ret <8 x i32> %add3
}
