Skip to content

Commit b1c9faf

Browse files
committed
Remove k2 loop from DGEMM TN and use a more conservative heuristic for SGEMM
1 parent 8c472ef commit b1c9faf

File tree

2 files changed, with 2 additions and 211 deletions.

kernel/arm64/dgemm_small_kernel_tn_sve.c

Lines changed: 1 addition & 208 deletions
Original file line number | Diff line number | Diff line change
@@ -265,43 +265,7 @@ CNAME(BLASLONG M,
265265

266266
if (LIKELY(packed_a != NULL)) {
267267
if (j == 0) {
268-
for (; k < k2; k += 2) {
269-
270-
VECTOR_LOAD_B_K2(0, 0);
271-
VECTOR_LOAD_B_K2(1, 0);
272-
TRANSPOSE_B2_K2(0, 1, 0, 1);
273-
SCALE_B2_K2(0, 0, 1);
274-
GATHER_LOAD_A(pg_true, 0, 0);
275-
VECTOR_PACK_A(0, 0);
276-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
277-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
278-
GATHER_LOAD_A(pg_true, 0, 1);
279-
VECTOR_PACK_A(0, 1);
280-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
281-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
282-
VECTOR_LOAD_B_K2(2, 0);
283-
VECTOR_LOAD_B_K2(3, 0);
284-
TRANSPOSE_B2_K2(2, 3, 0, 1);
285-
SCALE_B2_K2(2, 0, 1);
286-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 0);
287-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 0);
288-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 1);
289-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 1);
290-
GATHER_LOAD_A(pg_true, 1, 0);
291-
VECTOR_PACK_A(1, 0);
292-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
293-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
294-
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 2, 0, 0);
295-
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 2, 1, 0);
296-
GATHER_LOAD_A(pg_true, 1, 1);
297-
VECTOR_PACK_A(1, 1);
298-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
299-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
300-
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 2, 0, 1);
301-
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 2, 1, 1);
302-
}
303268
for (; k < K; k++) {
304-
305269
BROADCAST_LOAD_B(0, 0);
306270
GATHER_LOAD_A(pg_true, 0, 0);
307271
VECTOR_PACK_A(0, 0);
@@ -320,39 +284,7 @@ CNAME(BLASLONG M,
320284
UPDATE_RESULT_VECTOR(pg_true, 1, 3, 0);
321285
}
322286
} else {
323-
for (; k < k2; k += 2) {
324-
325-
VECTOR_LOAD_B_K2(0, 0);
326-
VECTOR_LOAD_B_K2(1, 0);
327-
TRANSPOSE_B2_K2(0, 1, 0, 1);
328-
SCALE_B2_K2(0, 0, 1);
329-
UNPACK_VECTOR_A(0, 0);
330-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
331-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
332-
UNPACK_VECTOR_A(0, 1);
333-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
334-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
335-
VECTOR_LOAD_B_K2(2, 0);
336-
VECTOR_LOAD_B_K2(3, 0);
337-
TRANSPOSE_B2_K2(2, 3, 0, 1);
338-
SCALE_B2_K2(2, 0, 1);
339-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 0);
340-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 0);
341-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 1);
342-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 1);
343-
UNPACK_VECTOR_A(1, 0);
344-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
345-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
346-
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 2, 0, 0);
347-
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 2, 1, 0);
348-
UNPACK_VECTOR_A(1, 1);
349-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
350-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
351-
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 2, 0, 1);
352-
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 2, 1, 1);
353-
}
354287
for (; k < K; k++) {
355-
356288
BROADCAST_LOAD_B(0, 0);
357289
UNPACK_VECTOR_A(0, 0);
358290
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -370,37 +302,6 @@ CNAME(BLASLONG M,
370302
}
371303
}
372304
} else {
373-
for (; k < k2; k += 2) {
374-
375-
VECTOR_LOAD_B_K2(0, 0);
376-
VECTOR_LOAD_B_K2(1, 0);
377-
TRANSPOSE_B2_K2(0, 1, 0, 1);
378-
SCALE_B2_K2(0, 0, 1);
379-
GATHER_LOAD_A(pg_true, 0, 0);
380-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
381-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
382-
GATHER_LOAD_A(pg_true, 0, 1);
383-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
384-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
385-
VECTOR_LOAD_B_K2(2, 0);
386-
VECTOR_LOAD_B_K2(3, 0);
387-
TRANSPOSE_B2_K2(2, 3, 0, 1);
388-
SCALE_B2_K2(2, 0, 1);
389-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 0);
390-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 0);
391-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 1);
392-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 1);
393-
GATHER_LOAD_A(pg_true, 1, 0);
394-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
395-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
396-
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 2, 0, 0);
397-
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 2, 1, 0);
398-
GATHER_LOAD_A(pg_true, 1, 1);
399-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
400-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
401-
UPDATE_RESULT_VECTOR_QUADWORD(1, 2, 2, 0, 1);
402-
UPDATE_RESULT_VECTOR_QUADWORD(1, 3, 2, 1, 1);
403-
}
404305
for (; k < K; k++) {
405306

406307
BROADCAST_LOAD_B(0, 0);
@@ -443,27 +344,7 @@ CNAME(BLASLONG M,
443344
DECLARE_RESULT_VECTOR(1, 1);
444345

445346
if (LIKELY(packed_a != NULL)) {
446-
for (; k < k2; k += 2) {
447-
448-
VECTOR_LOAD_B_K2(0, 0);
449-
VECTOR_LOAD_B_K2(1, 0);
450-
TRANSPOSE_B2_K2(0, 1, 0, 1);
451-
SCALE_B2_K2(0, 0, 1);
452-
UNPACK_VECTOR_A(0, 0);
453-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
454-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
455-
UNPACK_VECTOR_A(0, 1);
456-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
457-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
458-
UNPACK_VECTOR_A(1, 0);
459-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
460-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
461-
UNPACK_VECTOR_A(1, 1);
462-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
463-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
464-
}
465347
for (; k < K; k++) {
466-
467348
BROADCAST_LOAD_B(0, 0);
468349
UNPACK_VECTOR_A(0, 0);
469350
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -474,27 +355,7 @@ CNAME(BLASLONG M,
474355
UPDATE_RESULT_VECTOR(pg_true, 1, 1, 0);
475356
}
476357
} else {
477-
for (; k < k2; k += 2) {
478-
479-
VECTOR_LOAD_B_K2(0, 0);
480-
VECTOR_LOAD_B_K2(1, 0);
481-
TRANSPOSE_B2_K2(0, 1, 0, 1);
482-
SCALE_B2_K2(0, 0, 1);
483-
GATHER_LOAD_A(pg_true, 0, 0);
484-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
485-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
486-
GATHER_LOAD_A(pg_true, 0, 1);
487-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
488-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
489-
GATHER_LOAD_A(pg_true, 1, 0);
490-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 0);
491-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 0);
492-
GATHER_LOAD_A(pg_true, 1, 1);
493-
UPDATE_RESULT_VECTOR_QUADWORD(1, 0, 0, 0, 1);
494-
UPDATE_RESULT_VECTOR_QUADWORD(1, 1, 0, 1, 1);
495-
}
496358
for (; k < K; k++) {
497-
498359
BROADCAST_LOAD_B(0, 0);
499360
GATHER_LOAD_A(pg_true, 0, 0);
500361
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
@@ -570,27 +431,6 @@ CNAME(BLASLONG M,
570431
DECLARE_RESULT_VECTOR(0, 2);
571432
DECLARE_RESULT_VECTOR(0, 3);
572433

573-
for (; k < k2; k += 2) {
574-
575-
VECTOR_LOAD_B_K2(0, 0);
576-
VECTOR_LOAD_B_K2(1, 0);
577-
TRANSPOSE_B2_K2(0, 1, 0, 1);
578-
SCALE_B2_K2(0, 0, 1);
579-
GATHER_LOAD_A(pg_true, 0, 0);
580-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
581-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
582-
GATHER_LOAD_A(pg_true, 0, 1);
583-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
584-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
585-
VECTOR_LOAD_B_K2(2, 0);
586-
VECTOR_LOAD_B_K2(3, 0);
587-
TRANSPOSE_B2_K2(2, 3, 0, 1);
588-
SCALE_B2_K2(2, 0, 1);
589-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 0);
590-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 0);
591-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 1);
592-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 1);
593-
}
594434
for (; k < K; k++) {
595435

596436
BROADCAST_LOAD_B(0, 0);
@@ -619,19 +459,6 @@ CNAME(BLASLONG M,
619459
DECLARE_RESULT_VECTOR(0, 0);
620460
DECLARE_RESULT_VECTOR(0, 1);
621461

622-
for (; k < k2; k += 2) {
623-
624-
VECTOR_LOAD_B_K2(0, 0);
625-
VECTOR_LOAD_B_K2(1, 0);
626-
TRANSPOSE_B2_K2(0, 1, 0, 1);
627-
SCALE_B2_K2(0, 0, 1);
628-
GATHER_LOAD_A(pg_true, 0, 0);
629-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
630-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
631-
GATHER_LOAD_A(pg_true, 0, 1);
632-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
633-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
634-
}
635462
for (; k < K; k++) {
636463

637464
BROADCAST_LOAD_B(0, 0);
@@ -686,27 +513,6 @@ CNAME(BLASLONG M,
686513
DECLARE_RESULT_VECTOR(0, 2);
687514
DECLARE_RESULT_VECTOR(0, 3);
688515

689-
for (; k < k2; k += 2) {
690-
691-
VECTOR_LOAD_B_K2(0, 0);
692-
VECTOR_LOAD_B_K2(1, 0);
693-
TRANSPOSE_B2_K2(0, 1, 0, 1);
694-
SCALE_B2_K2(0, 0, 1);
695-
GATHER_LOAD_A(pg_tail, 0, 0);
696-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
697-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
698-
GATHER_LOAD_A(pg_tail, 0, 1);
699-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
700-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
701-
VECTOR_LOAD_B_K2(2, 0);
702-
VECTOR_LOAD_B_K2(3, 0);
703-
TRANSPOSE_B2_K2(2, 3, 0, 1);
704-
SCALE_B2_K2(2, 0, 1);
705-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 0);
706-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 0);
707-
UPDATE_RESULT_VECTOR_QUADWORD(0, 2, 2, 0, 1);
708-
UPDATE_RESULT_VECTOR_QUADWORD(0, 3, 2, 1, 1);
709-
}
710516
for (; k < K; k++) {
711517

712518
BROADCAST_LOAD_B(0, 0);
@@ -735,19 +541,6 @@ CNAME(BLASLONG M,
735541
DECLARE_RESULT_VECTOR(0, 0);
736542
DECLARE_RESULT_VECTOR(0, 1);
737543

738-
for (; k < k2; k += 2) {
739-
740-
VECTOR_LOAD_B_K2(0, 0);
741-
VECTOR_LOAD_B_K2(1, 0);
742-
TRANSPOSE_B2_K2(0, 1, 0, 1);
743-
SCALE_B2_K2(0, 0, 1);
744-
GATHER_LOAD_A(pg_tail, 0, 0);
745-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 0);
746-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 0);
747-
GATHER_LOAD_A(pg_tail, 0, 1);
748-
UPDATE_RESULT_VECTOR_QUADWORD(0, 0, 0, 0, 1);
749-
UPDATE_RESULT_VECTOR_QUADWORD(0, 1, 0, 1, 1);
750-
}
751544
for (; k < K; k++) {
752545

753546
BROADCAST_LOAD_B(0, 0);
@@ -787,4 +580,4 @@ CNAME(BLASLONG M,
787580
free(packed_a);
788581

789582
return 0;
790-
}
583+
}

kernel/arm64/gemm_small_kernel_permit_sve.c

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,9 @@ int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alph
3535
if (MNK <= 64*64*64)
3636
return 1;
3737
#else // sgemm
38-
if (MNK <= 256*256*256)
38+
if (MNK <= 64*64*64)
3939
return 1;
4040
#endif
4141

42-
43-
4442
return 0;
4543
}

0 commit comments

Comments
 (0)