Commit da7d0f4

Merge pull request #5427 from yuanjia111/develop
Optimize the gemv_t_vector.c kernel for RISCV64_ZVL256B target
2 parents: b3f247a + c2cc7a3

File tree: 1 file changed (+168, -79 lines)

kernel/riscv64/gemv_t_vector.c

Lines changed: 168 additions & 79 deletions
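
The change below widens the vector register grouping from LMUL=2 to LMUL=8 (e32m2/e64m2 to e32m8/e64m8) and, on the non-0p10 intrinsics path (#ifndef RISCV_0p10_INTRINSICS), processes four columns of A per outer iteration so each loaded chunk of x is reused four times, with one accumulator and one reduction per column. For orientation only, here is a minimal scalar sketch (not part of the commit) of that 4-column blocking; the function name, unit strides for x and y, and single precision are illustrative assumptions, while the real kernel handles inc_y, both precisions, and the strided-x path shown in the diff.

/* Scalar sketch of the 4-column blocking used by the updated kernel (not the RVV code).
 * Column-major A with leading dimension lda is assumed, matching the kernel's layout. */
void gemv_t_blocked_sketch(long m, long n, float alpha,
                           const float *a, long lda, const float *x, float *y)
{
    long i = 0;
    for (; i + 4 <= n; i += 4) {              /* four columns of A per pass */
        float s0 = 0.0f, s1 = 0.0f, s2 = 0.0f, s3 = 0.0f;
        for (long j = 0; j < m; j++) {
            float xj = x[j];                  /* one load of x, four uses */
            s0 += a[(i + 0) * lda + j] * xj;
            s1 += a[(i + 1) * lda + j] * xj;
            s2 += a[(i + 2) * lda + j] * xj;
            s3 += a[(i + 3) * lda + j] * xj;
        }
        y[i + 0] += alpha * s0;
        y[i + 1] += alpha * s1;
        y[i + 2] += alpha * s2;
        y[i + 3] += alpha * s3;
    }
    for (; i < n; i++) {                      /* leftover columns, one at a time */
        float s = 0.0f;
        for (long j = 0; j < m; j++)
            s += a[i * lda + j] * x[j];
        y[i] += alpha * s;
    }
}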
@@ -27,110 +27,199 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "common.h"
 #if !defined(DOUBLE)
-#define VSETVL(n) RISCV_RVV(vsetvl_e32m2)(n)
-#define FLOAT_V_T vfloat32m2_t
+#define VSETVL(n) RISCV_RVV(vsetvl_e32m8)(n)
+#define VSETVL_MAX_M1 RISCV_RVV(vsetvlmax_e32m1)
+#define FLOAT_V_T vfloat32m8_t
 #define FLOAT_V_T_M1 vfloat32m1_t
-#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m2)
-#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m2)
+#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m8)
 #ifdef RISCV_0p10_INTRINSICS
-#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f32m2_f32m1(v_res, va, vb, gvl)
+#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f32m8_f32m1(v_res, va, vb, gvl)
 #else
-#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f32m2_f32m1)
+#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f32m8_f32m1)
 #endif
-#define VFMACCVV_FLOAT RISCV_RVV(vfmacc_vv_f32m2)
-#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f32m2)
+#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f32m8)
+#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f32m8)
 #define VFMVVF_FLOAT_M1 RISCV_RVV(vfmv_v_f_f32m1)
-#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f32m2)
 #define xint_t int
 #else
-#define VSETVL(n) RISCV_RVV(vsetvl_e64m2)(n)
-#define FLOAT_V_T vfloat64m2_t
+#define VSETVL(n) RISCV_RVV(vsetvl_e64m8)(n)
+#define VSETVL_MAX_M1 RISCV_RVV(vsetvlmax_e64m1)
+#define FLOAT_V_T vfloat64m8_t
 #define FLOAT_V_T_M1 vfloat64m1_t
-#define VLEV_FLOAT RISCV_RVV(vle64_v_f64m2)
-#define VLSEV_FLOAT RISCV_RVV(vlse64_v_f64m2)
+#define VLEV_FLOAT RISCV_RVV(vle64_v_f64m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse64_v_f64m8)
 #ifdef RISCV_0p10_INTRINSICS
-#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f64m2_f64m1(v_res, va, vb, gvl)
+#define VFREDSUM_FLOAT(va, vb, gvl) vfredusum_vs_f64m8_f64m1(v_res, va, vb, gvl)
 #else
-#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f64m2_f64m1)
+#define VFREDSUM_FLOAT RISCV_RVV(vfredusum_vs_f64m8_f64m1)
 #endif
-#define VFMACCVV_FLOAT RISCV_RVV(vfmacc_vv_f64m2)
-#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f64m2)
+#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f64m8)
+#define VFMVVF_FLOAT RISCV_RVV(vfmv_v_f_f64m8)
 #define VFMVVF_FLOAT_M1 RISCV_RVV(vfmv_v_f_f64m1)
-#define VFMULVV_FLOAT RISCV_RVV(vfmul_vv_f64m2)
 #define xint_t long long
 #endif
 
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
 {
-    BLASLONG i = 0, j = 0, k = 0;
-    BLASLONG ix = 0, iy = 0;
-    FLOAT *a_ptr = a;
-    FLOAT temp;
+    BLASLONG i = 0, j = 0, k = 0;
+    BLASLONG ix = 0, iy = 0;
+    FLOAT *a_ptr = a;
+    FLOAT temp;
 
-    FLOAT_V_T va, vr, vx;
-    unsigned int gvl = 0;
-    FLOAT_V_T_M1 v_res;
+    FLOAT_V_T va, vr, vx;
+    unsigned int gvl = 0;
+    FLOAT_V_T_M1 v_res;
+    size_t vlmax = VSETVL_MAX_M1();
 
+#ifndef RISCV_0p10_INTRINSICS
+    FLOAT_V_T va0, va1, va2, va3, vr0, vr1, vr2, vr3;
+    FLOAT_V_T_M1 vec0, vec1, vec2, vec3;
+    FLOAT *a_ptrs[4], *y_ptrs[4];
+#endif
 
-    if(inc_x == 1){
-        for(i = 0; i < n; i++){
-            v_res = VFMVVF_FLOAT_M1(0, 1);
-            gvl = VSETVL(m);
-            j = 0;
-            vr = VFMVVF_FLOAT(0, gvl);
-            for(k = 0; k < m/gvl; k++){
-                va = VLEV_FLOAT(&a_ptr[j], gvl);
-                vx = VLEV_FLOAT(&x[j], gvl);
-                vr = VFMULVV_FLOAT(va, vx, gvl); // could vfmacc here and reduce outside loop
-                v_res = VFREDSUM_FLOAT(vr, v_res, gvl); // but that reordering diverges far enough from scalar path to make tests fail
-                j += gvl;
-            }
-            if(j < m){
-                gvl = VSETVL(m-j);
-                va = VLEV_FLOAT(&a_ptr[j], gvl);
-                vx = VLEV_FLOAT(&x[j], gvl);
-                vr = VFMULVV_FLOAT(va, vx, gvl);
-                v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
-            }
-            temp = (FLOAT)EXTRACT_FLOAT(v_res);
-            y[iy] += alpha * temp;
+    if(inc_x == 1){
+#ifndef RISCV_0p10_INTRINSICS
+        BLASLONG anr = n - n % 4;
+        for (; i < anr; i += 4) {
+            gvl = VSETVL(m);
+            j = 0;
+            for (int l = 0; l < 4; l++) {
+                a_ptrs[l] = a + (i + l) * lda;
+                y_ptrs[l] = y + (i + l) * inc_y;
+            }
+            vec0 = VFMVVF_FLOAT_M1(0.0, vlmax);
+            vec1 = VFMVVF_FLOAT_M1(0.0, vlmax);
+            vec2 = VFMVVF_FLOAT_M1(0.0, vlmax);
+            vec3 = VFMVVF_FLOAT_M1(0.0, vlmax);
+            vr0 = VFMVVF_FLOAT(0.0, gvl);
+            vr1 = VFMVVF_FLOAT(0.0, gvl);
+            vr2 = VFMVVF_FLOAT(0.0, gvl);
+            vr3 = VFMVVF_FLOAT(0.0, gvl);
+            for (k = 0; k < m / gvl; k++) {
+                va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+                va1 = VLEV_FLOAT(a_ptrs[1] + j, gvl);
+                va2 = VLEV_FLOAT(a_ptrs[2] + j, gvl);
+                va3 = VLEV_FLOAT(a_ptrs[3] + j, gvl);
 
+                vx = VLEV_FLOAT(x + j, gvl);
+                vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+                vr1 = VFMULVV_FLOAT(va1, vx, gvl);
+                vr2 = VFMULVV_FLOAT(va2, vx, gvl);
+                vr3 = VFMULVV_FLOAT(va3, vx, gvl);
+                // Floating-point addition does not satisfy the associative law, that is, (a + b) + c ≠ a + (b + c),
+                // so piecewise multiplication and reduction must be performed inside the loop body.
+                vec0 = VFREDSUM_FLOAT(vr0, vec0, gvl);
+                vec1 = VFREDSUM_FLOAT(vr1, vec1, gvl);
+                vec2 = VFREDSUM_FLOAT(vr2, vec2, gvl);
+                vec3 = VFREDSUM_FLOAT(vr3, vec3, gvl);
+                j += gvl;
+            }
+            if (j < m) {
+                gvl = VSETVL(m - j);
+                va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+                va1 = VLEV_FLOAT(a_ptrs[1] + j, gvl);
+                va2 = VLEV_FLOAT(a_ptrs[2] + j, gvl);
+                va3 = VLEV_FLOAT(a_ptrs[3] + j, gvl);
 
-            iy += inc_y;
-            a_ptr += lda;
-        }
-    }else{
-        BLASLONG stride_x = inc_x * sizeof(FLOAT);
-        for(i = 0; i < n; i++){
-            v_res = VFMVVF_FLOAT_M1(0, 1);
-            gvl = VSETVL(m);
-            j = 0;
-            ix = 0;
-            vr = VFMVVF_FLOAT(0, gvl);
-            for(k = 0; k < m/gvl; k++){
-                va = VLEV_FLOAT(&a_ptr[j], gvl);
-                vx = VLSEV_FLOAT(&x[ix], stride_x, gvl);
-                vr = VFMULVV_FLOAT(va, vx, gvl);
-                v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
-                j += gvl;
-                ix += inc_x * gvl;
-            }
-            if(j < m){
-                gvl = VSETVL(m-j);
-                va = VLEV_FLOAT(&a_ptr[j], gvl);
-                vx = VLSEV_FLOAT(&x[ix], stride_x, gvl);
-                vr = VFMULVV_FLOAT(va, vx, gvl);
-                v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
-            }
-            temp = (FLOAT)EXTRACT_FLOAT(v_res);
-            y[iy] += alpha * temp;
+                vx = VLEV_FLOAT(x + j, gvl);
+                vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+                vr1 = VFMULVV_FLOAT(va1, vx, gvl);
+                vr2 = VFMULVV_FLOAT(va2, vx, gvl);
+                vr3 = VFMULVV_FLOAT(va3, vx, gvl);
+                vec0 = VFREDSUM_FLOAT(vr0, vec0, gvl);
+                vec1 = VFREDSUM_FLOAT(vr1, vec1, gvl);
+                vec2 = VFREDSUM_FLOAT(vr2, vec2, gvl);
+                vec3 = VFREDSUM_FLOAT(vr3, vec3, gvl);
+            }
+            *y_ptrs[0] += alpha * (FLOAT)(EXTRACT_FLOAT(vec0));
+            *y_ptrs[1] += alpha * (FLOAT)(EXTRACT_FLOAT(vec1));
+            *y_ptrs[2] += alpha * (FLOAT)(EXTRACT_FLOAT(vec2));
+            *y_ptrs[3] += alpha * (FLOAT)(EXTRACT_FLOAT(vec3));
+        }
+        // deal with the tail
+        for (; i < n; i++) {
+            v_res = VFMVVF_FLOAT_M1(0, vlmax);
+            gvl = VSETVL(m);
+            j = 0;
+            a_ptrs[0] = a + i * lda;
+            y_ptrs[0] = y + i * inc_y;
+            vr0 = VFMVVF_FLOAT(0, gvl);
+            for (k = 0; k < m / gvl; k++) {
+                va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+                vx = VLEV_FLOAT(x + j, gvl);
+                vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+                v_res = VFREDSUM_FLOAT(vr0, v_res, gvl);
+                j += gvl;
+            }
+            if (j < m) {
+                gvl = VSETVL(m - j);
+                va0 = VLEV_FLOAT(a_ptrs[0] + j, gvl);
+                vx = VLEV_FLOAT(x + j, gvl);
+                vr0 = VFMULVV_FLOAT(va0, vx, gvl);
+                v_res = VFREDSUM_FLOAT(vr0, v_res, gvl);
+            }
+            *y_ptrs[0] += alpha * (FLOAT)(EXTRACT_FLOAT(v_res));
+        }
+#else
+        for(i = 0; i < n; i++){
+            v_res = VFMVVF_FLOAT_M1(0, 1);
+            gvl = VSETVL(m);
+            j = 0;
+            vr = VFMVVF_FLOAT(0, gvl);
+            for(k = 0; k < m/gvl; k++){
+                va = VLEV_FLOAT(&a_ptr[j], gvl);
+                vx = VLEV_FLOAT(&x[j], gvl);
+                vr = VFMULVV_FLOAT(va, vx, gvl); // could vfmacc here and reduce outside loop
+                v_res = VFREDSUM_FLOAT(vr, v_res, gvl); // but that reordering diverges far enough from scalar path to make tests fail
+                j += gvl;
+            }
+            if(j < m){
+                gvl = VSETVL(m-j);
+                va = VLEV_FLOAT(&a_ptr[j], gvl);
+                vx = VLEV_FLOAT(&x[j], gvl);
+                vr = VFMULVV_FLOAT(va, vx, gvl);
+                v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
+            }
+            temp = (FLOAT)EXTRACT_FLOAT(v_res);
+            y[iy] += alpha * temp;
 
 
-            iy += inc_y;
-            a_ptr += lda;
+            iy += inc_y;
+            a_ptr += lda;
+        }
+#endif
+    } else {
+        BLASLONG stride_x = inc_x * sizeof(FLOAT);
+        for(i = 0; i < n; i++){
+            v_res = VFMVVF_FLOAT_M1(0, 1);
+            gvl = VSETVL(m);
+            j = 0;
+            ix = 0;
+            vr = VFMVVF_FLOAT(0, gvl);
+            for(k = 0; k < m/gvl; k++){
+                va = VLEV_FLOAT(&a_ptr[j], gvl);
+                vx = VLSEV_FLOAT(&x[ix], stride_x, gvl);
+                vr = VFMULVV_FLOAT(va, vx, gvl);
+                v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
+                j += gvl;
+                ix += inc_x * gvl;
+            }
+            if(j < m){
+                gvl = VSETVL(m-j);
+                va = VLEV_FLOAT(&a_ptr[j], gvl);
+                vx = VLSEV_FLOAT(&x[ix], stride_x, gvl);
+                vr = VFMULVV_FLOAT(va, vx, gvl);
+                v_res = VFREDSUM_FLOAT(vr, v_res, gvl);
             }
-    }
+            temp = (FLOAT)EXTRACT_FLOAT(v_res);
+            y[iy] += alpha * temp;
 
 
-    return(0);
+            iy += inc_y;
+            a_ptr += lda;
+        }
+    }
+
+    return (0);
 }
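
The in-diff comments about vfmacc and the associative law come down to one point: the scalar reference sums each column in a fixed order, and a vector schedule that regroups those additions can change the low bits of the result, so the kernel reduces every chunk inside the loop rather than accumulating with vfmacc and reducing once at the end. A small stand-alone C example (not from the commit; values chosen for IEEE-754 binary32) shows how regrouping alone changes a float sum:

#include <stdio.h>

int main(void)
{
    /* Same three addends, two groupings: float addition is not associative. */
    float big = 1.0e8f, small = 4.0f;

    float left_to_right = (big + small) + small; /* each small is absorbed by big */
    float regrouped     = big + (small + small); /* the pair survives rounding */

    /* Typically prints "100000000.0 vs 100000008.0" with round-to-nearest-even. */
    printf("%.1f vs %.1f\n", left_to_right, regrouped);
    return 0;
}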
