Skip to content

Commit 2afd741

Browse files
author
tingbo.liao
committed
Optimize the rotm kernel with RVV intrinsic.
Signed-off-by: tingbo.liao <[email protected]>
1 parent a107547 commit 2afd741

File tree

3 files changed

+249
-0
lines changed

3 files changed

+249
-0
lines changed

common_riscv64.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ static inline int blas_quickdivide(blasint x, blasint y){
9191

9292
#if defined(C910V) || defined(RISCV64_ZVL256B) || defined(RISCV64_ZVL128B) || defined(x280)
9393
# include <riscv_vector.h>
94+
#define RISCV_SIMD
9495
#endif
9596

9697
#if defined( __riscv_xtheadc ) && defined( __riscv_v ) && ( __riscv_v <= 7000 )

interface/rotm.c

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,26 @@
33
#include "functable.h"
44
#endif
55

6+
#if defined(RISCV_SIMD)
7+
#if !defined(DOUBLE)
8+
#define VSETVL(n) __riscv_vsetvl_e32m8(n)
9+
#define FLOAT_V_T vfloat32m8_t
10+
#define VLSEV_FLOAT __riscv_vlse32_v_f32m8
11+
#define VSSEV_FLOAT __riscv_vsse32_v_f32m8
12+
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f32m8
13+
#define VFMULVF_FLOAT __riscv_vfmul_vf_f32m8
14+
#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f32m8
15+
#else
16+
#define VSETVL(n) __riscv_vsetvl_e64m8(n)
17+
#define FLOAT_V_T vfloat64m8_t
18+
#define VLSEV_FLOAT __riscv_vlse64_v_f64m8
19+
#define VSSEV_FLOAT __riscv_vsse64_v_f64m8
20+
#define VFMACCVF_FLOAT __riscv_vfmacc_vf_f64m8
21+
#define VFMULVF_FLOAT __riscv_vfmul_vf_f64m8
22+
#define VFMSACVF_FLOAT __riscv_vfmsac_vf_f64m8
23+
#endif
24+
#endif
25+
626
#ifndef CBLAS
727

828
void NAME(blasint *N, FLOAT *dx, blasint *INCX, FLOAT *dy, blasint *INCY, FLOAT *dparam){
@@ -25,6 +45,11 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
2545
FLOAT dh11, dh12, dh22, dh21, dflag;
2646
blasint nsteps;
2747

48+
#if defined(RISCV_SIMD)
49+
FLOAT_V_T v_w, v_z__, v_dx, v_dy;
50+
blasint stride, stride_x, stride_y, offset;
51+
#endif
52+
2853
#ifndef CBLAS
2954
PRINT_DEBUG_CNAME;
3055
#else
@@ -53,26 +78,74 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
5378
dh21 = dparam[3];
5479
i__1 = nsteps;
5580
i__2 = incx;
81+
#if !defined(RISCV_SIMD)
5682
for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) {
5783
w = dx[i__];
5884
z__ = dy[i__];
5985
dx[i__] = w + z__ * dh12;
6086
dy[i__] = w * dh21 + z__;
6187
/* L20: */
6288
}
89+
#else
90+
if(i__2 < 0){
91+
offset = i__1 - 2;
92+
dx += offset;
93+
dy += offset;
94+
i__1 = -i__1;
95+
i__2 = -i__2;
96+
}
97+
stride = i__2 * sizeof(FLOAT);
98+
n = i__1 / i__2;
99+
for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) {
100+
vl = VSETVL(n);
101+
102+
v_w = VLSEV_FLOAT(&dx[1], stride, vl);
103+
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);
104+
105+
v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl);
106+
v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl);
107+
108+
VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
109+
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
110+
}
111+
#endif
63112
goto L140;
64113
L30:
65114
dh11 = dparam[2];
66115
dh22 = dparam[5];
67116
i__2 = nsteps;
68117
i__1 = incx;
118+
#if !defined(RISCV_SIMD)
69119
for (i__ = 1; i__1 < 0 ? i__ >= i__2 : i__ <= i__2; i__ += i__1) {
70120
w = dx[i__];
71121
z__ = dy[i__];
72122
dx[i__] = w * dh11 + z__;
73123
dy[i__] = -w + dh22 * z__;
74124
/* L40: */
75125
}
126+
#else
127+
if(i__1 < 0){
128+
offset = i__2 - 2;
129+
dx += offset;
130+
dy += offset;
131+
i__1 = -i__1;
132+
i__2 = -i__2;
133+
}
134+
stride = i__1 * sizeof(FLOAT);
135+
n = i__2 / i__1;
136+
for (size_t vl; n > 0; n -= vl, dx += vl*i__1, dy += vl*i__1) {
137+
vl = VSETVL(n);
138+
139+
v_w = VLSEV_FLOAT(&dx[1], stride, vl);
140+
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);
141+
142+
v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl);
143+
v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl);
144+
145+
VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
146+
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
147+
}
148+
#endif
76149
goto L140;
77150
L50:
78151
dh11 = dparam[2];
@@ -81,13 +154,39 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
81154
dh22 = dparam[5];
82155
i__1 = nsteps;
83156
i__2 = incx;
157+
#if !defined(RISCV_SIMD)
84158
for (i__ = 1; i__2 < 0 ? i__ >= i__1 : i__ <= i__1; i__ += i__2) {
85159
w = dx[i__];
86160
z__ = dy[i__];
87161
dx[i__] = w * dh11 + z__ * dh12;
88162
dy[i__] = w * dh21 + z__ * dh22;
89163
/* L60: */
90164
}
165+
#else
166+
if(i__2 < 0){
167+
offset = i__1 - 2;
168+
dx += offset;
169+
dy += offset;
170+
i__1 = -i__1;
171+
i__2 = -i__2;
172+
}
173+
stride = i__2 * sizeof(FLOAT);
174+
n = i__1 / i__2;
175+
for (size_t vl; n > 0; n -= vl, dx += vl*i__2, dy += vl*i__2) {
176+
vl = VSETVL(n);
177+
178+
v_w = VLSEV_FLOAT(&dx[1], stride, vl);
179+
v_z__ = VLSEV_FLOAT(&dy[1], stride, vl);
180+
181+
v_dx = VFMULVF_FLOAT(v_w, dh11, vl);
182+
v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl);
183+
VSSEV_FLOAT(&dx[1], stride, v_dx, vl);
184+
185+
v_dy = VFMULVF_FLOAT(v_w, dh21, vl);
186+
v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl);
187+
VSSEV_FLOAT(&dy[1], stride, v_dy, vl);
188+
}
189+
#endif
91190
goto L140;
92191
L70:
93192
kx = 1;
@@ -110,6 +209,7 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
110209
dh12 = dparam[4];
111210
dh21 = dparam[3];
112211
i__2 = n;
212+
#if !defined(RISCV_SIMD)
113213
for (i__ = 1; i__ <= i__2; ++i__) {
114214
w = dx[kx];
115215
z__ = dy[ky];
@@ -119,11 +219,36 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
119219
ky += incy;
120220
/* L90: */
121221
}
222+
#else
223+
if(incx < 0){
224+
incx = -incx;
225+
dx -= n*incx;
226+
}
227+
if(incy < 0){
228+
incy = -incy;
229+
dy -= n*incy;
230+
}
231+
stride_x = incx * sizeof(FLOAT);
232+
stride_y = incy * sizeof(FLOAT);
233+
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
234+
vl = VSETVL(n);
235+
236+
v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
237+
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);
238+
239+
v_dx = VFMACCVF_FLOAT(v_w, dh12, v_z__, vl);
240+
v_dy = VFMACCVF_FLOAT(v_z__, dh21, v_w, vl);
241+
242+
VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
243+
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
244+
}
245+
#endif
122246
goto L140;
123247
L100:
124248
dh11 = dparam[2];
125249
dh22 = dparam[5];
126250
i__2 = n;
251+
#if !defined(RISCV_SIMD)
127252
for (i__ = 1; i__ <= i__2; ++i__) {
128253
w = dx[kx];
129254
z__ = dy[ky];
@@ -133,8 +258,33 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
133258
ky += incy;
134259
/* L110: */
135260
}
261+
#else
262+
if(incx < 0){
263+
incx = -incx;
264+
dx -= n*incx;
265+
}
266+
if(incy < 0){
267+
incy = -incy;
268+
dy -= n*incy;
269+
}
270+
stride_x = incx * sizeof(FLOAT);
271+
stride_y = incy * sizeof(FLOAT);
272+
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
273+
vl = VSETVL(n);
274+
275+
v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
276+
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);
277+
278+
v_dx = VFMACCVF_FLOAT(v_z__, dh11, v_w, vl);
279+
v_dy = VFMSACVF_FLOAT(v_w, dh22, v_z__, vl);
280+
281+
VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
282+
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
283+
}
284+
#endif
136285
goto L140;
137286
L120:
287+
#if !defined(RISCV_SIMD)
138288
dh11 = dparam[2];
139289
dh12 = dparam[4];
140290
dh21 = dparam[3];
@@ -149,6 +299,32 @@ void CNAME(blasint n, FLOAT *dx, blasint incx, FLOAT *dy, blasint incy, FLOAT *d
149299
ky += incy;
150300
/* L130: */
151301
}
302+
#else
303+
if(incx < 0){
304+
incx = -incx;
305+
dx -= n*incx;
306+
}
307+
if(incy < 0){
308+
incy = -incy;
309+
dy -= n*incy;
310+
}
311+
stride_x = incx * sizeof(FLOAT);
312+
stride_y = incy * sizeof(FLOAT);
313+
for (size_t vl; n > 0; n -= vl, dx += vl*incx, dy += vl*incy) {
314+
vl = VSETVL(n);
315+
316+
v_w = VLSEV_FLOAT(&dx[kx], stride_x, vl);
317+
v_z__ = VLSEV_FLOAT(&dy[ky], stride_y, vl);
318+
319+
v_dx = VFMULVF_FLOAT(v_w, dh11, vl);
320+
v_dx = VFMACCVF_FLOAT(v_dx, dh12, v_z__, vl);
321+
VSSEV_FLOAT(&dx[kx], stride_x, v_dx, vl);
322+
323+
v_dy = VFMULVF_FLOAT(v_w, dh21, vl);
324+
v_dy = VFMACCVF_FLOAT(v_dy, dh22, v_z__, vl);
325+
VSSEV_FLOAT(&dy[ky], stride_y, v_dy, vl);
326+
}
327+
#endif
152328
L140:
153329
return;
154330
}

utest/test_rot.c

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,42 @@ CTEST(rot,drot_inc_0)
5353
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS);
5454
}
5555
}
56+
CTEST(rot,drot_inc_1)
57+
{
58+
blasint i=0;
59+
blasint N=4,incX=1,incY=1;
60+
double c=1.0,s=1.0;
61+
double x1[]={1.0,3.0,5.0,7.0};
62+
double y1[]={2.0,4.0,6.0,8.0};
63+
double x2[]={3.0,7.0,11.0,15.0};
64+
double y2[]={1.0,1.0,1.0,1.0};
65+
66+
//OpenBLAS
67+
BLASFUNC(drot)(&N,x1,&incX,y1,&incY,&c,&s);
68+
69+
for(i=0; i<N; i++){
70+
ASSERT_DBL_NEAR_TOL(x2[i], x1[i], DOUBLE_EPS);
71+
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], DOUBLE_EPS);
72+
}
73+
}
74+
CTEST(rot,drotm_inc_1)
75+
{
76+
blasint i = 0;
77+
blasint N = 12, incX = 1, incY = 1;
78+
double param[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
79+
double x_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
80+
double y_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
81+
double x_referece[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0};
82+
double y_referece[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0};
83+
84+
//OpenBLAS
85+
BLASFUNC(drotm)(&N, x_actual, &incX, y_actual, &incY, param);
86+
87+
for(i = 0; i < N; i++){
88+
ASSERT_DBL_NEAR_TOL(x_referece[i], x_actual[i], DOUBLE_EPS);
89+
ASSERT_DBL_NEAR_TOL(y_referece[i], y_actual[i], DOUBLE_EPS);
90+
}
91+
}
5692
#endif
5793

5894
#ifdef BUILD_COMPLEX16
@@ -96,6 +132,42 @@ CTEST(rot,srot_inc_0)
96132
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS);
97133
}
98134
}
135+
CTEST(rot,srot_inc_1)
136+
{
137+
blasint i=0;
138+
blasint N=4,incX=1,incY=1;
139+
float c=1.0,s=1.0;
140+
float x1[]={1.0,3.0,5.0,7.0};
141+
float y1[]={2.0,4.0,6.0,8.0};
142+
float x2[]={3.0,7.0,11.0,15.0};
143+
float y2[]={1.0,1.0,1.0,1.0};
144+
145+
//OpenBLAS
146+
BLASFUNC(srot)(&N,x1,&incX,y1,&incY,&c,&s);
147+
148+
for(i=0; i<N; i++){
149+
ASSERT_DBL_NEAR_TOL(x2[i], x1[i], SINGLE_EPS);
150+
ASSERT_DBL_NEAR_TOL(y2[i], y1[i], SINGLE_EPS);
151+
}
152+
}
153+
CTEST(rot,srotm_inc_1)
154+
{
155+
blasint i = 0;
156+
blasint N = 12, incX = 1, incY = 1;
157+
float param[5] = {1.0, 2.0, 3.0, 4.0, 5.0};
158+
float x_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
159+
float y_actual[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0};
160+
float x_referece[] = {3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0, 27.0, 30.0, 33.0, 36.0};
161+
float y_referece[] = {4.0, 8.0, 12.0, 16.0, 20.0, 24.0, 28.0, 32.0, 36.0, 40.0, 44.0, 48.0};
162+
163+
//OpenBLAS
164+
BLASFUNC(srotm)(&N, x_actual, &incX, y_actual, &incY, param);
165+
166+
for(i = 0; i < N; i++){
167+
ASSERT_DBL_NEAR_TOL(x_referece[i], x_actual[i], SINGLE_EPS);
168+
ASSERT_DBL_NEAR_TOL(y_referece[i], y_actual[i], SINGLE_EPS);
169+
}
170+
}
99171
#endif
100172

101173
#ifdef BUILD_COMPLEX

0 commit comments

Comments
 (0)