Skip to content

Commit 22aa81f

Browse files
committed
s390x: fix cscal and zscal implementations
The implementation of complex scalar * vector multiplication for Z14 makes some LAPACK tests fail because the numerical differences from the reference implementation exceed the threshold (as can be seen by running `make lapack-test` and replacing kernel/zarch/cscal.c with a generic implementation for comparison).

The complex multiplication uses terms of the form a * b + c * d for both the real and imaginary parts. The assembly code (and compiler-emitted code as well) computes the first product with a plain multiplication, which rounds it, and then uses a fused multiply-add operation for the second product and the sum. The results can be "surprising", for example when both terms in the imaginary part nearly cancel each other out. In that case, the second product, which enters the sum unrounded, contributes more digits to the sum than the first product, which has already been rounded.

One option is to use separate multiplications (which then round the same way) and a distinct add. Change the code to pursue that path by (1) requesting the compiler not to contract the operations into FMAs and (2) replacing the assembly kernel with corresponding vectorized C code (where change 1 also applies).

Signed-off-by: Marius Hillenbrand <[email protected]>
1 parent 77ea73f commit 22aa81f

File tree

2 files changed

+60
-130
lines changed

2 files changed

+60
-130
lines changed

kernel/zarch/cscal.c

Lines changed: 30 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -25,67 +25,35 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
2525
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
*****************************************************************************/
2727

28+
/*
29+
* Avoid contraction of floating point operations, specifically fused
30+
* multiply-add, because they can cause unexpected results in complex
31+
* multiplication.
32+
*/
33+
#if defined(__GNUC__) && !defined(__clang__)
34+
#pragma GCC optimize ("fp-contract=off")
35+
#endif
36+
37+
#if defined(__clang__)
38+
#pragma clang fp contract(off)
39+
#endif
40+
2841
#include "common.h"
42+
#include "vector-common.h"
2943

30-
static void cscal_kernel_16(BLASLONG n, FLOAT *alpha, FLOAT *x) {
31-
__asm__("vlrepf %%v0,0(%[alpha])\n\t"
32-
"vlef %%v1,4(%[alpha]),0\n\t"
33-
"vlef %%v1,4(%[alpha]),2\n\t"
34-
"vflcsb %%v1,%%v1\n\t"
35-
"vlef %%v1,4(%[alpha]),1\n\t"
36-
"vlef %%v1,4(%[alpha]),3\n\t"
37-
"srlg %[n],%[n],4\n\t"
38-
"xgr %%r1,%%r1\n\t"
39-
"0:\n\t"
40-
"pfd 2, 1024(%%r1,%[x])\n\t"
41-
"vl %%v16,0(%%r1,%[x])\n\t"
42-
"vl %%v17,16(%%r1,%[x])\n\t"
43-
"vl %%v18,32(%%r1,%[x])\n\t"
44-
"vl %%v19,48(%%r1,%[x])\n\t"
45-
"vl %%v20,64(%%r1,%[x])\n\t"
46-
"vl %%v21,80(%%r1,%[x])\n\t"
47-
"vl %%v22,96(%%r1,%[x])\n\t"
48-
"vl %%v23,112(%%r1,%[x])\n\t"
49-
"verllg %%v24,%%v16,32\n\t"
50-
"verllg %%v25,%%v17,32\n\t"
51-
"verllg %%v26,%%v18,32\n\t"
52-
"verllg %%v27,%%v19,32\n\t"
53-
"verllg %%v28,%%v20,32\n\t"
54-
"verllg %%v29,%%v21,32\n\t"
55-
"verllg %%v30,%%v22,32\n\t"
56-
"verllg %%v31,%%v23,32\n\t"
57-
"vfmsb %%v16,%%v16,%%v0\n\t"
58-
"vfmsb %%v17,%%v17,%%v0\n\t"
59-
"vfmsb %%v18,%%v18,%%v0\n\t"
60-
"vfmsb %%v19,%%v19,%%v0\n\t"
61-
"vfmsb %%v20,%%v20,%%v0\n\t"
62-
"vfmsb %%v21,%%v21,%%v0\n\t"
63-
"vfmsb %%v22,%%v22,%%v0\n\t"
64-
"vfmsb %%v23,%%v23,%%v0\n\t"
65-
"vfmasb %%v16,%%v24,%%v1,%%v16\n\t"
66-
"vfmasb %%v17,%%v25,%%v1,%%v17\n\t"
67-
"vfmasb %%v18,%%v26,%%v1,%%v18\n\t"
68-
"vfmasb %%v19,%%v27,%%v1,%%v19\n\t"
69-
"vfmasb %%v20,%%v28,%%v1,%%v20\n\t"
70-
"vfmasb %%v21,%%v29,%%v1,%%v21\n\t"
71-
"vfmasb %%v22,%%v30,%%v1,%%v22\n\t"
72-
"vfmasb %%v23,%%v31,%%v1,%%v23\n\t"
73-
"vst %%v16,0(%%r1,%[x])\n\t"
74-
"vst %%v17,16(%%r1,%[x])\n\t"
75-
"vst %%v18,32(%%r1,%[x])\n\t"
76-
"vst %%v19,48(%%r1,%[x])\n\t"
77-
"vst %%v20,64(%%r1,%[x])\n\t"
78-
"vst %%v21,80(%%r1,%[x])\n\t"
79-
"vst %%v22,96(%%r1,%[x])\n\t"
80-
"vst %%v23,112(%%r1,%[x])\n\t"
81-
"agfi %%r1,128\n\t"
82-
"brctg %[n],0b"
83-
: "+m"(*(FLOAT (*)[n * 2]) x),[n] "+&r"(n)
84-
: [x] "a"(x), "m"(*(const FLOAT (*)[2]) alpha),
85-
[alpha] "a"(alpha)
86-
: "cc", "r1", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21",
87-
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
88-
"v31");
44+
static void cscal_kernel_16(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x) {
45+
vector_float da_r_vec = vec_splats(da_r);
46+
vector_float da_i_vec = { -da_i, da_i, -da_i, da_i };
47+
48+
vector_float *x_vec_ptr = (vector_float *)x;
49+
50+
#pragma GCC unroll 16
51+
for (size_t i = 0; i < n/2; i++) {
52+
vector_float x_vec = vec_load_hinted(x + i * VLEN_FLOATS);
53+
vector_float x_swapped = {x_vec[1], x_vec[0], x_vec[3], x_vec[2]};
54+
55+
x_vec_ptr[i] = x_vec * da_r_vec + x_swapped * da_i_vec;
56+
}
8957
}
9058

9159
static void cscal_kernel_16_zero_r(BLASLONG n, FLOAT *alpha, FLOAT *x) {
@@ -199,14 +167,12 @@ static void cscal_kernel_16_zero(BLASLONG n, FLOAT *x) {
199167
: "cc", "r1", "v0");
200168
}
201169

202-
static void cscal_kernel_inc_8(BLASLONG n, FLOAT *alpha, FLOAT *x,
170+
static void cscal_kernel_inc_8(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x,
203171
BLASLONG inc_x) {
204172
BLASLONG i;
205173
BLASLONG inc_x2 = 2 * inc_x;
206174
BLASLONG inc_x3 = inc_x2 + inc_x;
207175
FLOAT t0, t1, t2, t3;
208-
FLOAT da_r = alpha[0];
209-
FLOAT da_i = alpha[1];
210176

211177
for (i = 0; i < n; i += 4) {
212178
t0 = da_r * x[0] - da_i * x[1];
@@ -324,9 +290,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
324290

325291
BLASLONG n1 = n & -8;
326292
if (n1 > 0) {
327-
alpha[0] = da_r;
328-
alpha[1] = da_i;
329-
cscal_kernel_inc_8(n1, alpha, x, inc_x);
293+
cscal_kernel_inc_8(n1, da_r, da_i, x, inc_x);
330294
j = n1;
331295
i = n1 * inc_x;
332296
}
@@ -362,7 +326,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
362326
else if (da_i == 0)
363327
cscal_kernel_16_zero_i(n1, alpha, x);
364328
else
365-
cscal_kernel_16(n1, alpha, x);
329+
cscal_kernel_16(n1, da_r, da_i, x);
366330

367331
i = n1 << 1;
368332
j = n1;

kernel/zarch/zscal.c

Lines changed: 30 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -25,65 +25,35 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
2525
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
*****************************************************************************/
2727

28+
/*
29+
* Avoid contraction of floating point operations, specifically fused
30+
* multiply-add, because they can cause unexpected results in complex
31+
* multiplication.
32+
*/
33+
#if defined(__GNUC__) && !defined(__clang__)
34+
#pragma GCC optimize ("fp-contract=off")
35+
#endif
36+
37+
#if defined(__clang__)
38+
#pragma clang fp contract(off)
39+
#endif
40+
2841
#include "common.h"
42+
#include "vector-common.h"
2943

30-
static void zscal_kernel_8(BLASLONG n, FLOAT *alpha, FLOAT *x) {
31-
__asm__("vlrepg %%v0,0(%[alpha])\n\t"
32-
"vleg %%v1,8(%[alpha]),0\n\t"
33-
"wflcdb %%v1,%%v1\n\t"
34-
"vleg %%v1,8(%[alpha]),1\n\t"
35-
"srlg %[n],%[n],3\n\t"
36-
"xgr %%r1,%%r1\n\t"
37-
"0:\n\t"
38-
"pfd 2, 1024(%%r1,%[x])\n\t"
39-
"vl %%v16,0(%%r1,%[x])\n\t"
40-
"vl %%v17,16(%%r1,%[x])\n\t"
41-
"vl %%v18,32(%%r1,%[x])\n\t"
42-
"vl %%v19,48(%%r1,%[x])\n\t"
43-
"vl %%v20,64(%%r1,%[x])\n\t"
44-
"vl %%v21,80(%%r1,%[x])\n\t"
45-
"vl %%v22,96(%%r1,%[x])\n\t"
46-
"vl %%v23,112(%%r1,%[x])\n\t"
47-
"vpdi %%v24,%%v16,%%v16,4\n\t"
48-
"vpdi %%v25,%%v17,%%v17,4\n\t"
49-
"vpdi %%v26,%%v18,%%v18,4\n\t"
50-
"vpdi %%v27,%%v19,%%v19,4\n\t"
51-
"vpdi %%v28,%%v20,%%v20,4\n\t"
52-
"vpdi %%v29,%%v21,%%v21,4\n\t"
53-
"vpdi %%v30,%%v22,%%v22,4\n\t"
54-
"vpdi %%v31,%%v23,%%v23,4\n\t"
55-
"vfmdb %%v16,%%v16,%%v0\n\t"
56-
"vfmdb %%v17,%%v17,%%v0\n\t"
57-
"vfmdb %%v18,%%v18,%%v0\n\t"
58-
"vfmdb %%v19,%%v19,%%v0\n\t"
59-
"vfmdb %%v20,%%v20,%%v0\n\t"
60-
"vfmdb %%v21,%%v21,%%v0\n\t"
61-
"vfmdb %%v22,%%v22,%%v0\n\t"
62-
"vfmdb %%v23,%%v23,%%v0\n\t"
63-
"vfmadb %%v16,%%v24,%%v1,%%v16\n\t"
64-
"vfmadb %%v17,%%v25,%%v1,%%v17\n\t"
65-
"vfmadb %%v18,%%v26,%%v1,%%v18\n\t"
66-
"vfmadb %%v19,%%v27,%%v1,%%v19\n\t"
67-
"vfmadb %%v20,%%v28,%%v1,%%v20\n\t"
68-
"vfmadb %%v21,%%v29,%%v1,%%v21\n\t"
69-
"vfmadb %%v22,%%v30,%%v1,%%v22\n\t"
70-
"vfmadb %%v23,%%v31,%%v1,%%v23\n\t"
71-
"vst %%v16,0(%%r1,%[x])\n\t"
72-
"vst %%v17,16(%%r1,%[x])\n\t"
73-
"vst %%v18,32(%%r1,%[x])\n\t"
74-
"vst %%v19,48(%%r1,%[x])\n\t"
75-
"vst %%v20,64(%%r1,%[x])\n\t"
76-
"vst %%v21,80(%%r1,%[x])\n\t"
77-
"vst %%v22,96(%%r1,%[x])\n\t"
78-
"vst %%v23,112(%%r1,%[x])\n\t"
79-
"agfi %%r1,128\n\t"
80-
"brctg %[n],0b"
81-
: "+m"(*(FLOAT (*)[n * 2]) x),[n] "+&r"(n)
82-
: [x] "a"(x), "m"(*(const FLOAT (*)[2]) alpha),
83-
[alpha] "a"(alpha)
84-
: "cc", "r1", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21",
85-
"v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
86-
"v31");
44+
static void zscal_kernel_8(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x) {
45+
vector_float da_r_vec = vec_splats(da_r);
46+
vector_float da_i_vec = { -da_i, da_i };
47+
48+
vector_float * x_vec_ptr = (vector_float *)x;
49+
50+
#pragma GCC unroll 16
51+
for (size_t i = 0; i < n; i++) {
52+
vector_float x_vec = vec_load_hinted(x + i * VLEN_FLOATS);
53+
vector_float x_swapped = {x_vec[1], x_vec[0]};
54+
55+
x_vec_ptr[i] = x_vec * da_r_vec + x_swapped * da_i_vec;
56+
}
8757
}
8858

8959
static void zscal_kernel_8_zero_r(BLASLONG n, FLOAT *alpha, FLOAT *x) {
@@ -195,14 +165,12 @@ static void zscal_kernel_8_zero(BLASLONG n, FLOAT *x) {
195165
: "cc", "r1", "v0");
196166
}
197167

198-
static void zscal_kernel_inc_8(BLASLONG n, FLOAT *alpha, FLOAT *x,
168+
static void zscal_kernel_inc_8(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x,
199169
BLASLONG inc_x) {
200170
BLASLONG i;
201171
BLASLONG inc_x2 = 2 * inc_x;
202172
BLASLONG inc_x3 = inc_x2 + inc_x;
203173
FLOAT t0, t1, t2, t3;
204-
FLOAT da_r = alpha[0];
205-
FLOAT da_i = alpha[1];
206174

207175
for (i = 0; i < n; i += 4) {
208176
t0 = da_r * x[0] - da_i * x[1];
@@ -320,9 +288,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
320288

321289
BLASLONG n1 = n & -8;
322290
if (n1 > 0) {
323-
alpha[0] = da_r;
324-
alpha[1] = da_i;
325-
zscal_kernel_inc_8(n1, alpha, x, inc_x);
291+
zscal_kernel_inc_8(n1, da_r, da_i, x, inc_x);
326292
j = n1;
327293
i = n1 * inc_x;
328294
}
@@ -358,7 +324,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
358324
else if (da_i == 0)
359325
zscal_kernel_8_zero_i(n1, alpha, x);
360326
else
361-
zscal_kernel_8(n1, alpha, x);
327+
zscal_kernel_8(n1, da_r, da_i, x);
362328

363329
i = n1 << 1;
364330
j = n1;

0 commit comments

Comments
 (0)