Skip to content

Commit 1b25083

Browse files
committed
arm64: Fix nrm2 for input vectors with Inf
Fix double precision nrm2 kernels returning NaN when the input vectors contain Inf/-Inf.
1 parent cd898af commit 1b25083

File tree

4 files changed

+42
-19
lines changed

4 files changed

+42
-19
lines changed

kernel/arm64/KERNEL.NEOVERSEN1

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ IDAMAXKERNEL = iamax_thunderx2t99.c
9191
ICAMAXKERNEL = izamax_thunderx2t99.c
9292
IZAMAXKERNEL = izamax_thunderx2t99.c
9393

94-
SNRM2KERNEL = nrm2.S
95-
DNRM2KERNEL = nrm2.S
96-
CNRM2KERNEL = znrm2.S
97-
ZNRM2KERNEL = znrm2.S
94+
SNRM2KERNEL = scnrm2_thunderx2t99.c
95+
DNRM2KERNEL = dznrm2_thunderx2t99.c
96+
CNRM2KERNEL = scnrm2_thunderx2t99.c
97+
ZNRM2KERNEL = dznrm2_thunderx2t99.c
9898

9999
DDOTKERNEL = dot_thunderx2t99.c
100100
SDOTKERNEL = dot_thunderx2t99.c

kernel/arm64/KERNEL.THUNDERX2T99

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,12 @@ IDAMAXKERNEL = iamax_thunderx2t99.c
153153
ICAMAXKERNEL = izamax_thunderx2t99.c
154154
IZAMAXKERNEL = izamax_thunderx2t99.c
155155

156-
SNRM2KERNEL = nrm2.S
157-
CNRM2KERNEL = nrm2.S
156+
SNRM2KERNEL = scnrm2_thunderx2t99.c
157+
CNRM2KERNEL = scnrm2_thunderx2t99.c
158158
#DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
159159
#ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
160-
DNRM2KERNEL = znrm2.S
161-
ZNRM2KERNEL = znrm2.S
160+
DNRM2KERNEL = dznrm2_thunderx2t99.c
161+
ZNRM2KERNEL = dznrm2_thunderx2t99.c
162162

163163

164164
DDOTKERNEL = dot_thunderx2t99.c

kernel/arm64/KERNEL.THUNDERX3T110

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,13 @@ IDAMAXKERNEL = iamax_thunderx2t99.c
153153
ICAMAXKERNEL = izamax_thunderx2t99.c
154154
IZAMAXKERNEL = izamax_thunderx2t99.c
155155

156-
#SNRM2KERNEL = scnrm2_thunderx2t99.c
157-
#CNRM2KERNEL = scnrm2_thunderx2t99.c
158-
##DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
159-
##ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
160-
#DNRM2KERNEL = dznrm2_thunderx2t99.c
161-
#ZNRM2KERNEL = dznrm2_thunderx2t99.c
162-
SNRM2KERNEL = nrm2.S
163-
DNRM2KERNEL = nrm2.S
164-
CNRM2KERNEL = znrm2.S
165-
ZNRM2KERNEL = znrm2.S
156+
SNRM2KERNEL = scnrm2_thunderx2t99.c
157+
CNRM2KERNEL = scnrm2_thunderx2t99.c
158+
#DNRM2KERNEL = dznrm2_thunderx2t99_fast.c
159+
#ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c
160+
DNRM2KERNEL = dznrm2_thunderx2t99.c
161+
ZNRM2KERNEL = dznrm2_thunderx2t99.c
162+
166163

167164
DDOTKERNEL = dot_thunderx2t99.c
168165
SDOTKERNEL = dot_thunderx2t99.c

kernel/arm64/dznrm2_thunderx2t99.c

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n
5858
#define CUR_MAXINV "d8"
5959
#define CUR_MAXINV_V "v8.2d"
6060
#define CUR_MAX_V "v8.2d"
61+
#define REGINF "d9"
6162

6263
static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
6364
double *ssq, double *scale)
@@ -79,8 +80,10 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
7980
" ble 9f //nrm2_kernel_L999 \n"
8081

8182
"1: //nrm2_kernel_F_BEGIN: \n"
83+
" mov x6, #0x7FF0000000000000 //+Infinity \n"
8284
" fmov "REGZERO", xzr \n"
8385
" fmov "REGONE", #1.0 \n"
86+
" fmov "REGINF", x6 \n"
8487
" lsl "INC_X", "INC_X", #"INC_SHIFT" \n"
8588
" mov "J", "N" \n"
8689
" cmp "J", xzr \n"
@@ -104,6 +107,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
104107
" ldr d4, ["X"] \n"
105108
" fabs d4, d4 \n"
106109
" fmax "CUR_MAX", "SCALE", d4 \n"
110+
" fcmp "CUR_MAX", "REGINF" \n"
111+
" beq 10f \n"
107112
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
108113
" fmul "SCALE", "SCALE", "SCALE" \n"
109114
" fmul "SSQ", "SSQ", "SCALE" \n"
@@ -116,6 +121,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
116121
" ldr d3, ["X", #8] \n"
117122
" fabs d3, d3 \n"
118123
" fmax "CUR_MAX", "SCALE", d3 \n"
124+
" fcmp "CUR_MAX", "REGINF" \n"
125+
" beq 10f \n"
119126
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
120127
" fmul "SCALE", "SCALE", "SCALE" \n"
121128
" fmul "SSQ", "SSQ", "SCALE" \n"
@@ -158,6 +165,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
158165
" fmaxp v24.2d, v24.2d, v26.2d \n"
159166
" fmaxp v24.2d, v24.2d, v24.2d \n"
160167
" fmax "CUR_MAX", "SCALE", d24 \n"
168+
" fcmp "CUR_MAX", "REGINF" \n"
169+
" beq 10f \n"
161170
" fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n"
162171
" //dup "CUR_MAX_V", v7.d[0] \n"
163172
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
@@ -217,6 +226,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
217226
" fmaxp v24.2d, v24.2d, v26.2d \n"
218227
" fmaxp v24.2d, v24.2d, v24.2d \n"
219228
" fmax "CUR_MAX", "SCALE", d24 \n"
229+
" fcmp "CUR_MAX", "REGINF" \n"
230+
" beq 10f \n"
220231
" fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n"
221232
" //dup "CUR_MAX_V", v7.d[0] \n"
222233
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
@@ -265,6 +276,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
265276
" ldr d4, ["X"] \n"
266277
" fabs d4, d4 \n"
267278
" fmax "CUR_MAX", "SCALE", d4 \n"
279+
" fcmp "CUR_MAX", "REGINF" \n"
280+
" beq 10f \n"
268281
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
269282
" fmul "SCALE", "SCALE", "SCALE" \n"
270283
" fmul "SSQ", "SSQ", "SCALE" \n"
@@ -276,6 +289,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
276289
" ldr d3, ["X", #8] \n"
277290
" fabs d3, d3 \n"
278291
" fmax "CUR_MAX", "SCALE", d3 \n"
292+
" fcmp "CUR_MAX", "REGINF" \n"
293+
" beq 10f \n"
279294
" fdiv "SCALE", "SCALE", "CUR_MAX" \n"
280295
" fmul "SCALE", "SCALE", "SCALE" \n"
281296
" fmul "SSQ", "SSQ", "SCALE" \n"
@@ -291,6 +306,11 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
291306
"9: //nrm2_kernel_L999: \n"
292307
" str "SSQ", [%[SSQ_]] \n"
293308
" str "SCALE", [%[SCALE_]] \n"
309+
" b 11f \n"
310+
"10: \n"
311+
" str "REGINF", [%[SSQ_]] \n"
312+
" str "REGINF", [%[SCALE_]] \n"
313+
"11: \n"
294314

295315
:
296316
: [SSQ_] "r" (ssq), //%0
@@ -300,7 +320,7 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x,
300320
[INCX_] "r" (inc_x) //%4
301321
: "cc",
302322
"memory",
303-
"x0", "x1", "x2", "x3", "x4", "x5",
323+
"x0", "x1", "x2", "x3", "x4", "x5", "x6",
304324
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8"
305325
);
306326

@@ -359,6 +379,12 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x)
359379
cur_ssq = *ptr;
360380
cur_scale = *(ptr + 1);
361381

382+
if (cur_ssq == INFINITY) {
383+
ssq = INFINITY;
384+
scale = INFINITY;
385+
break;
386+
}
387+
362388
if (cur_scale != 0) {
363389
if (cur_scale > scale) {
364390
scale = (scale / cur_scale);

0 commit comments

Comments
 (0)