Skip to content

Commit fca86e3

Browse files
authored
Merge pull request #4887 from goplanid/develop
Small GEMM improvements for AArch64 with SVE
2 parents 60c1519 + 4894c54 commit fca86e3

File tree

1 file changed: 222 additions and 6 deletions

1 file changed

+222
-6
lines changed

kernel/arm64/dgemm_small_kernel_tn_sve.c

Lines changed: 222 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -211,6 +211,7 @@ CNAME(BLASLONG M,
211211
const BLASLONG v_m1 = M & -v_size;
212212
const BLASLONG n4 = N & -4;
213213
const BLASLONG n2 = N & -2;
214+
const BLASLONG n8 = N & -8;
214215

215216
const int pack_a = M >= v_size2 && N >= 8 && K >= 8 ? 1 : 0;
216217
FLOAT* packed_a =
@@ -229,28 +230,37 @@ CNAME(BLASLONG M,
229230
CREATE_A_POINTER(1, v_size);
230231

231232
BLASLONG j = 0;
232-
for (; j < n4; j += 4) {
233-
233+
for (; j < n8; j += 8) {
234234
CREATE_B_POINTER(0, 0);
235235
CREATE_B_POINTER(1, 1);
236236
CREATE_B_POINTER(2, 2);
237237
CREATE_B_POINTER(3, 3);
238-
UPDATE_B_POINTER(4);
238+
CREATE_B_POINTER(4, 4);
239+
CREATE_B_POINTER(5, 5);
240+
CREATE_B_POINTER(6, 6);
241+
CREATE_B_POINTER(7, 7);
242+
UPDATE_B_POINTER(8);
239243

240244
BLASLONG k = 0;
241245
DECLARE_RESULT_VECTOR(0, 0);
242246
DECLARE_RESULT_VECTOR(0, 1);
243247
DECLARE_RESULT_VECTOR(0, 2);
244248
DECLARE_RESULT_VECTOR(0, 3);
249+
DECLARE_RESULT_VECTOR(0, 4);
250+
DECLARE_RESULT_VECTOR(0, 5);
251+
DECLARE_RESULT_VECTOR(0, 6);
252+
DECLARE_RESULT_VECTOR(0, 7);
245253
DECLARE_RESULT_VECTOR(1, 0);
246254
DECLARE_RESULT_VECTOR(1, 1);
247255
DECLARE_RESULT_VECTOR(1, 2);
248256
DECLARE_RESULT_VECTOR(1, 3);
249-
257+
DECLARE_RESULT_VECTOR(1, 4);
258+
DECLARE_RESULT_VECTOR(1, 5);
259+
DECLARE_RESULT_VECTOR(1, 6);
260+
DECLARE_RESULT_VECTOR(1, 7);
250261
if (LIKELY(packed_a != NULL)) {
251262
if (j == 0) {
252263
for (; k < K; k++) {
253-
254264
BROADCAST_LOAD_B(0, 0);
255265
GATHER_LOAD_A(pg_true, 0, 0);
256266
VECTOR_PACK_A(0, 0);
@@ -267,10 +277,21 @@ CNAME(BLASLONG M,
267277
BROADCAST_LOAD_B(3, 0);
268278
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
269279
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
280+
BROADCAST_LOAD_B(4, 0);
281+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
282+
UPDATE_RESULT_VECTOR(pg_true, 1, 4, 0);
283+
BROADCAST_LOAD_B(5, 0);
284+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
285+
UPDATE_RESULT_VECTOR(pg_true, 1, 5, 0);
286+
BROADCAST_LOAD_B(6, 0);
287+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
288+
UPDATE_RESULT_VECTOR(pg_true, 1, 6, 0);
289+
BROADCAST_LOAD_B(7, 0);
290+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
291+
UPDATE_RESULT_VECTOR(pg_true, 1, 7, 0);
270292
}
271293
} else {
272294
for (; k < K; k++) {
273-
274295
BROADCAST_LOAD_B(0, 0);
275296
UNPACK_VECTOR_A(0, 0);
276297
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -285,7 +306,104 @@ CNAME(BLASLONG M,
285306
BROADCAST_LOAD_B(3, 0);
286307
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
287308
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
309+
BROADCAST_LOAD_B(4, 0);
310+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
311+
UPDATE_RESULT_VECTOR(pg_true, 1, 4, 0);
312+
BROADCAST_LOAD_B(5, 0);
313+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
314+
UPDATE_RESULT_VECTOR(pg_true, 1, 5, 0);
315+
BROADCAST_LOAD_B(6, 0);
316+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
317+
UPDATE_RESULT_VECTOR(pg_true, 1, 6, 0);
318+
BROADCAST_LOAD_B(7, 0);
319+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
320+
UPDATE_RESULT_VECTOR(pg_true, 1, 7, 0);
288321
}
322+
}
323+
} else {
324+
for (; k < K; k++) {
325+
BROADCAST_LOAD_B(0, 0);
326+
GATHER_LOAD_A(pg_true, 0, 0);
327+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
328+
BROADCAST_LOAD_B(1, 0);
329+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
330+
GATHER_LOAD_A(pg_true, 1, 0);
331+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
332+
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
333+
BROADCAST_LOAD_B(2, 0);
334+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
335+
UPDATE_RESULT_VECTOR(pg_true, 1, 2, 0);
336+
BROADCAST_LOAD_B(3, 0);
337+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
338+
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
339+
BROADCAST_LOAD_B(4, 0);
340+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
341+
UPDATE_RESULT_VECTOR(pg_true, 1, 4, 0);
342+
BROADCAST_LOAD_B(5, 0);
343+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
344+
UPDATE_RESULT_VECTOR(pg_true, 1, 5, 0);
345+
BROADCAST_LOAD_B(6, 0);
346+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
347+
UPDATE_RESULT_VECTOR(pg_true, 1, 6, 0);
348+
BROADCAST_LOAD_B(7, 0);
349+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
350+
UPDATE_RESULT_VECTOR(pg_true, 1, 7, 0);
351+
}
352+
}
353+
VECTOR_STORE(pg_true, 0, 0);
354+
VECTOR_STORE(pg_true, 0, 1);
355+
VECTOR_STORE(pg_true, 0, 2);
356+
VECTOR_STORE(pg_true, 0, 3);
357+
VECTOR_STORE(pg_true, 0, 4);
358+
VECTOR_STORE(pg_true, 0, 5);
359+
VECTOR_STORE(pg_true, 0, 6);
360+
VECTOR_STORE(pg_true, 0, 7);
361+
VECTOR_STORE(pg_true, 1, 0);
362+
VECTOR_STORE(pg_true, 1, 1);
363+
VECTOR_STORE(pg_true, 1, 2);
364+
VECTOR_STORE(pg_true, 1, 3);
365+
VECTOR_STORE(pg_true, 1, 4);
366+
VECTOR_STORE(pg_true, 1, 5);
367+
VECTOR_STORE(pg_true, 1, 6);
368+
VECTOR_STORE(pg_true, 1, 7);
369+
INCR_C_POINTER(0, 8);
370+
INCR_C_POINTER(1, 8);
371+
}
372+
for (; j < n4; j += 4) {
373+
374+
CREATE_B_POINTER(0, 0);
375+
CREATE_B_POINTER(1, 1);
376+
CREATE_B_POINTER(2, 2);
377+
CREATE_B_POINTER(3, 3);
378+
UPDATE_B_POINTER(4);
379+
380+
BLASLONG k = 0;
381+
DECLARE_RESULT_VECTOR(0, 0);
382+
DECLARE_RESULT_VECTOR(0, 1);
383+
DECLARE_RESULT_VECTOR(0, 2);
384+
DECLARE_RESULT_VECTOR(0, 3);
385+
DECLARE_RESULT_VECTOR(1, 0);
386+
DECLARE_RESULT_VECTOR(1, 1);
387+
DECLARE_RESULT_VECTOR(1, 2);
388+
DECLARE_RESULT_VECTOR(1, 3);
389+
390+
if (LIKELY(packed_a != NULL)) {
391+
for (; k < K; k++) {
392+
393+
BROADCAST_LOAD_B(0, 0);
394+
UNPACK_VECTOR_A(0, 0);
395+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
396+
BROADCAST_LOAD_B(1, 0);
397+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
398+
UNPACK_VECTOR_A(1, 0);
399+
UPDATE_RESULT_VECTOR(pg_true, 1, 0, 0);
400+
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
401+
BROADCAST_LOAD_B(2, 0);
402+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
403+
UPDATE_RESULT_VECTOR(pg_true, 1, 2, 0);
404+
BROADCAST_LOAD_B(3, 0);
405+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
406+
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
289407
}
290408
} else {
291409
for (; k < K; k++) {
@@ -405,6 +523,55 @@ CNAME(BLASLONG M,
405523
CREATE_A_POINTER(0, 0);
406524

407525
BLASLONG j = 0;
526+
for (; j < n8; j += 8) {
527+
CREATE_B_POINTER(0, 0);
528+
CREATE_B_POINTER(1, 1);
529+
CREATE_B_POINTER(2, 2);
530+
CREATE_B_POINTER(3, 3);
531+
CREATE_B_POINTER(4, 4);
532+
CREATE_B_POINTER(5, 5);
533+
CREATE_B_POINTER(6, 6);
534+
CREATE_B_POINTER(7, 7);
535+
UPDATE_B_POINTER(8);
536+
537+
BLASLONG k = 0;
538+
DECLARE_RESULT_VECTOR(0, 0);
539+
DECLARE_RESULT_VECTOR(0, 1);
540+
DECLARE_RESULT_VECTOR(0, 2);
541+
DECLARE_RESULT_VECTOR(0, 3);
542+
DECLARE_RESULT_VECTOR(0, 4);
543+
DECLARE_RESULT_VECTOR(0, 5);
544+
DECLARE_RESULT_VECTOR(0, 6);
545+
DECLARE_RESULT_VECTOR(0, 7);
546+
for (; k < K; k++) {
547+
BROADCAST_LOAD_B(0, 0);
548+
GATHER_LOAD_A(pg_true, 0, 0);
549+
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
550+
BROADCAST_LOAD_B(1, 0);
551+
UPDATE_RESULT_VECTOR(pg_true, 0, 1, 0);
552+
BROADCAST_LOAD_B(2, 0);
553+
UPDATE_RESULT_VECTOR(pg_true, 0, 2, 0);
554+
BROADCAST_LOAD_B(3, 0);
555+
UPDATE_RESULT_VECTOR(pg_true, 0, 3, 0);
556+
BROADCAST_LOAD_B(4, 0);
557+
UPDATE_RESULT_VECTOR(pg_true, 0, 4, 0);
558+
BROADCAST_LOAD_B(5, 0);
559+
UPDATE_RESULT_VECTOR(pg_true, 0, 5, 0);
560+
BROADCAST_LOAD_B(6, 0);
561+
UPDATE_RESULT_VECTOR(pg_true, 0, 6, 0);
562+
BROADCAST_LOAD_B(7, 0);
563+
UPDATE_RESULT_VECTOR(pg_true, 0, 7, 0);
564+
}
565+
VECTOR_STORE(pg_true, 0, 0);
566+
VECTOR_STORE(pg_true, 0, 1);
567+
VECTOR_STORE(pg_true, 0, 2);
568+
VECTOR_STORE(pg_true, 0, 3);
569+
VECTOR_STORE(pg_true, 0, 4);
570+
VECTOR_STORE(pg_true, 0, 5);
571+
VECTOR_STORE(pg_true, 0, 6);
572+
VECTOR_STORE(pg_true, 0, 7);
573+
INCR_C_POINTER(0, 8);
574+
}
408575
for (; j < n4; j += 4) {
409576

410577
CREATE_B_POINTER(0, 0);
@@ -487,6 +654,55 @@ CNAME(BLASLONG M,
487654
CREATE_A_POINTER(0, 0);
488655

489656
BLASLONG j = 0;
657+
for (; j < n8; j += 8) {
658+
CREATE_B_POINTER(0, 0);
659+
CREATE_B_POINTER(1, 1);
660+
CREATE_B_POINTER(2, 2);
661+
CREATE_B_POINTER(3, 3);
662+
CREATE_B_POINTER(4, 4);
663+
CREATE_B_POINTER(5, 5);
664+
CREATE_B_POINTER(6, 6);
665+
CREATE_B_POINTER(7, 7);
666+
UPDATE_B_POINTER(8);
667+
668+
BLASLONG k = 0;
669+
DECLARE_RESULT_VECTOR(0, 0);
670+
DECLARE_RESULT_VECTOR(0, 1);
671+
DECLARE_RESULT_VECTOR(0, 2);
672+
DECLARE_RESULT_VECTOR(0, 3);
673+
DECLARE_RESULT_VECTOR(0, 4);
674+
DECLARE_RESULT_VECTOR(0, 5);
675+
DECLARE_RESULT_VECTOR(0, 6);
676+
DECLARE_RESULT_VECTOR(0, 7);
677+
for (; k < K; k++) {
678+
BROADCAST_LOAD_B(0, 0);
679+
GATHER_LOAD_A(pg_tail, 0, 0);
680+
UPDATE_RESULT_VECTOR(pg_tail, 0, 0, 0);
681+
BROADCAST_LOAD_B(1, 0);
682+
UPDATE_RESULT_VECTOR(pg_tail, 0, 1, 0);
683+
BROADCAST_LOAD_B(2, 0);
684+
UPDATE_RESULT_VECTOR(pg_tail, 0, 2, 0);
685+
BROADCAST_LOAD_B(3, 0);
686+
UPDATE_RESULT_VECTOR(pg_tail, 0, 3, 0);
687+
BROADCAST_LOAD_B(4, 0);
688+
UPDATE_RESULT_VECTOR(pg_tail, 0, 4, 0);
689+
BROADCAST_LOAD_B(5, 0);
690+
UPDATE_RESULT_VECTOR(pg_tail, 0, 5, 0);
691+
BROADCAST_LOAD_B(6, 0);
692+
UPDATE_RESULT_VECTOR(pg_tail, 0, 6, 0);
693+
BROADCAST_LOAD_B(7, 0);
694+
UPDATE_RESULT_VECTOR(pg_tail, 0, 7, 0);
695+
}
696+
VECTOR_STORE(pg_tail, 0, 0);
697+
VECTOR_STORE(pg_tail, 0, 1);
698+
VECTOR_STORE(pg_tail, 0, 2);
699+
VECTOR_STORE(pg_tail, 0, 3);
700+
VECTOR_STORE(pg_tail, 0, 4);
701+
VECTOR_STORE(pg_tail, 0, 5);
702+
VECTOR_STORE(pg_tail, 0, 6);
703+
VECTOR_STORE(pg_tail, 0, 7);
704+
INCR_C_POINTER(0, 8);
705+
}
490706
for (; j < n4; j += 4) {
491707

492708
CREATE_B_POINTER(0, 0);

0 commit comments

Comments
 (0)