Skip to content

Commit 8c472ef

Browse files
committed
Further tweak small GEMM for AArch64
1 parent 7a6fa69 commit 8c472ef

9 files changed

+771
-2285
lines changed

kernel/arm64/dgemm_small_kernel_nn_sve.c

Lines changed: 109 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,27 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4646
})
4747
#endif
4848

49-
#define A_ELEMENT_K(m, offset_k) A[(i + (m)) + (k + offset_k) * lda]
49+
// Rewind the running A pointer (a_offset) to the base of A. Invoked at the
// end of each column-block (j) iteration, after the row loops have advanced
// a_offset via UPDATE_A_POINTER. The trailing ';' is part of the expansion.
#define RESET_A_POINTER() a_offset = A;
50+
51+
// Declare a local pointer a_offset<m> located `scale` elements past the
// running A pointer (a_offset). `m` is token-pasted into the variable name and
// must therefore be a plain token; `scale` is parenthesized so that compound
// expressions are applied as a single offset. The trailing ';' is part of the
// expansion because this macro produces a declaration.
#define CREATE_A_POINTER(m, scale) FLOAT* a_offset##m = a_offset + (scale);
52+
// Advance the running A pointer (a_offset) by `scale` elements. `scale` is
// parenthesized so a compound argument is added as one offset. The trailing
// ';' is part of the expansion.
#define UPDATE_A_POINTER(scale) a_offset = a_offset + (scale);
53+
// Lvalue for the element of A addressed through a_offset<m>, at column
// (k + offset_k) with column stride lda. `m` is token-pasted and must be a
// plain token; `offset_k` and the whole expansion are parenthesized so the
// macro is safe inside larger expressions (e.g. `&A_ELEMENT_K(...)`).
#define A_ELEMENT_K(m, offset_k) (*(a_offset##m + (k + (offset_k)) * lda))
5054
#define A_ELEMENT(m) A_ELEMENT_K(m, 0)
5155

52-
#define B_ELEMENT_K(n, offset_k) B[(k + offset_k) + (j + (n)) * ldb]
56+
// Point the running B pointer (b_offset) back at the base of B.
// The trailing ';' is part of the expansion.
#define RESET_B_POINTER() b_offset = B;
57+
58+
// Declare a local pointer b_offset<n> located `scale` columns (scale * ldb
// elements) past the running B pointer (b_offset). `n` is token-pasted into
// the variable name and must be a plain token; `scale` is parenthesized so a
// compound argument such as `j + 1` is multiplied by ldb as a whole. The
// trailing ';' is part of the expansion because this macro is a declaration.
#define CREATE_B_POINTER(n, scale) FLOAT* b_offset##n = b_offset + (scale) * ldb;
59+
// Advance the running B pointer (b_offset) by `scale` columns, i.e. by
// scale * ldb elements. `scale` is parenthesized so a compound argument is
// multiplied by ldb as a whole. The trailing ';' is part of the expansion.
#define UPDATE_B_POINTER(scale) b_offset = b_offset + (scale) * ldb;
60+
// Lvalue for the element of B addressed through b_offset<n>, at row
// (k + offset_k). `n` is token-pasted and must be a plain token; `offset_k`
// and the whole expansion are parenthesized so the macro is safe inside
// larger expressions (e.g. `&B_ELEMENT_K(...)`).
#define B_ELEMENT_K(n, offset_k) (*(b_offset##n + (k + (offset_k))))
5361
#define B_ELEMENT(n) B_ELEMENT_K(n, 0)
5462

55-
#define C_ELEMENT(m, n) C[(i + (m)) + (j + (n)) * ldc]
63+
// Declare a local pointer c_offset<n> located `scale` columns (scale * ldc
// elements) past the running C pointer (c_offset). `n` is token-pasted into
// the variable name and must be a plain token; `scale` is parenthesized so a
// compound argument is multiplied by ldc as a whole. The trailing ';' is part
// of the expansion because this macro is a declaration.
#define CREATE_C_POINTER(n, scale) FLOAT* c_offset##n = c_offset + (scale) * ldc;
64+
// Currently a no-op: the macro body is empty, so every INCR_C_POINTER(...)
// call site expands to nothing and the c_offset<n> pointers are not advanced.
// C_ELEMENT adds the row index `i` itself when addressing C.
// The disabled body was: c_offset ## m += incr;
#define INCR_C_POINTER(m, incr)
65+
// Advance the running C pointer (c_offset) by `scale` columns, i.e. by
// scale * ldc elements. `scale` is parenthesized so a compound argument is
// multiplied by ldc as a whole. The trailing ';' is part of the expansion.
#define UPDATE_C_POINTER(scale) c_offset = c_offset + (scale) * ldc;
66+
// Lvalue for the element of C addressed through c_offset<n>, at row
// (m * v_size + i): `m` selects the vector-sized row block and `i` is the
// row offset of the current iteration. `n` is token-pasted and must be a plain
// token; `m` and the whole expansion are parenthesized so the macro is safe
// with compound arguments and inside larger expressions (e.g. `&C_ELEMENT(...)`).
#define C_ELEMENT(m, n) (*(c_offset##n + (((m) * v_size) + i)))
67+
68+
// #undef C_ELEMENT
69+
// #define C_ELEMENT(m, n) C[(i+(m))+(j+(n))*ldc]
5670

5771
#define PACK_ELEMENT_K(n, offset_k) packed_b[(k + offset_k) * 4 + n]
5872
#define PACK_ELEMENT(n) PACK_ELEMENT_K(n, 0)
@@ -112,8 +126,7 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
112126
#define BROADCAST_LOAD_B(n, offset_k) \
113127
svfloat64_t b##s##n##_k##offset_k = svdup_f64(B_ELEMENT_K(n, offset_k));
114128
#define VECTOR_LOAD_A(pg, m, offset_k) \
115-
svfloat64_t a##s##m##_k##offset_k = \
116-
svld1(pg, &A_ELEMENT_K(v_size * m, offset_k));
129+
svfloat64_t a##s##m##_k##offset_k = svld1(pg, &A_ELEMENT_K(m, offset_k));
117130
#define QUADWORD_LOAD_B(n, offset_k) \
118131
svfloat64_t b##s##n##_k##offset_k = \
119132
svld1rq(pg_true, &B_ELEMENT_K(n, offset_k));
@@ -140,26 +153,23 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
140153
#ifdef B0
141154
#define VECTOR_STORE(pg, m, n) \
142155
result##m##n = svmul_m(pg, result##m##n, alpha_vec); \
143-
svst1(pg, &C_ELEMENT(v_size* m, n), result##m##n);
156+
svst1(pg, &C_ELEMENT(m, n), result##m##n);
144157
#define SCATTER_STORE(pg, m, n) \
145158
result##m##n = svmul_m(pg, result##m##n, alpha_vec); \
146-
svst1_scatter_index( \
147-
pg, &C_ELEMENT(v_size* m, n), svindex_u64(0LL, ldc), result##m##n);
159+
svst1_scatter_index(pg, &C_ELEMENT(m, n), ldc_vec, result##m##n);
148160
#else
149161
#define VECTOR_STORE(pg, m, n) \
150162
result##m##n = svmul_m(pg, result##m##n, alpha_vec); \
151163
result##m##n = \
152-
svmla_m(pg, result##m##n, svld1(pg, &C_ELEMENT(v_size * m, n)), beta_vec); \
153-
svst1(pg, &C_ELEMENT(v_size* m, n), result##m##n);
164+
svmla_m(pg, result##m##n, svld1(pg, &C_ELEMENT(m, n)), beta_vec); \
165+
svst1(pg, &C_ELEMENT(m, n), result##m##n);
154166
#define SCATTER_STORE(pg, m, n) \
155167
result##m##n = svmul_m(pg, result##m##n, alpha_vec); \
156-
result##m##n = svmla_m( \
157-
pg, \
158-
result##m##n, \
159-
svld1_gather_index(pg, &C_ELEMENT(v_size * m, n), svindex_u64(0LL, ldc)), \
160-
beta_vec); \
161-
svst1_scatter_index( \
162-
pg, &C_ELEMENT(v_size* m, n), svindex_u64(0LL, ldc), result##m##n);
168+
result##m##n = svmla_m(pg, \
169+
result##m##n, \
170+
svld1_gather_index(pg, &C_ELEMENT(m, n), ldc_vec), \
171+
beta_vec); \
172+
svst1_scatter_index(pg, &C_ELEMENT(m, n), ldc_vec, result##m##n);
163173
#endif
164174

165175
#ifndef LIKELY
@@ -169,13 +179,6 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
169179
#define LIKELY(x) (x)
170180
#endif
171181
#endif
172-
#ifndef UNLIKELY
173-
#ifdef __GNUC__
174-
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
175-
#else
176-
#define UNLIKELY(x) (x)
177-
#endif
178-
#endif
179182

180183
#ifdef B0
181184
int
@@ -223,12 +226,29 @@ CNAME(BLASLONG M,
223226
FLOAT* packed_b =
224227
(pack_b) ? packed_b = (FLOAT*)malloc(K * 4 * sizeof(FLOAT)) : NULL;
225228

229+
FLOAT* b_offset = B;
230+
FLOAT* a_offset = A;
231+
FLOAT* c_offset = C;
232+
226233
BLASLONG j = 0;
227234
for (; j < n4; j += 4) {
228235

236+
CREATE_C_POINTER(0, 0);
237+
CREATE_C_POINTER(1, 1);
238+
CREATE_C_POINTER(2, 2);
239+
CREATE_C_POINTER(3, 3);
240+
CREATE_B_POINTER(0, 0);
241+
CREATE_B_POINTER(1, 1);
242+
CREATE_B_POINTER(2, 2);
243+
CREATE_B_POINTER(3, 3);
244+
229245
BLASLONG i = 0;
230246
for (; i < v_m2; i += v_size2) {
231247

248+
CREATE_A_POINTER(0, 0);
249+
CREATE_A_POINTER(1, v_size);
250+
UPDATE_A_POINTER(v_size2);
251+
232252
BLASLONG k = 0;
233253
DECLARE_RESULT_VECTOR(0, 0);
234254
DECLARE_RESULT_VECTOR(0, 1);
@@ -372,9 +392,16 @@ CNAME(BLASLONG M,
372392
VECTOR_STORE(pg_true, 1, 1);
373393
VECTOR_STORE(pg_true, 1, 2);
374394
VECTOR_STORE(pg_true, 1, 3);
395+
INCR_C_POINTER(0, v_size2);
396+
INCR_C_POINTER(1, v_size2);
397+
INCR_C_POINTER(2, v_size2);
398+
INCR_C_POINTER(3, v_size2);
375399
}
376400
for (; i < v_m1; i += v_size) {
377401

402+
CREATE_A_POINTER(0, 0);
403+
UPDATE_A_POINTER(v_size);
404+
378405
BLASLONG k = 0;
379406
DECLARE_RESULT_VECTOR(0, 0);
380407
DECLARE_RESULT_VECTOR(0, 1);
@@ -431,9 +458,15 @@ CNAME(BLASLONG M,
431458
VECTOR_STORE(pg_true, 0, 1);
432459
VECTOR_STORE(pg_true, 0, 2);
433460
VECTOR_STORE(pg_true, 0, 3);
461+
INCR_C_POINTER(0, v_size);
462+
INCR_C_POINTER(1, v_size);
463+
INCR_C_POINTER(2, v_size);
464+
INCR_C_POINTER(3, v_size);
434465
}
435466
for (; i < M; i += v_size) {
436467
const svbool_t pg_tail = svwhilelt_b64((uint64_t)i, (uint64_t)(M));
468+
CREATE_A_POINTER(0, 0);
469+
UPDATE_A_POINTER(0);
437470

438471
BLASLONG k = 0;
439472
DECLARE_RESULT_VECTOR(0, 0);
@@ -491,13 +524,30 @@ CNAME(BLASLONG M,
491524
VECTOR_STORE(pg_tail, 0, 1);
492525
VECTOR_STORE(pg_tail, 0, 2);
493526
VECTOR_STORE(pg_tail, 0, 3);
527+
INCR_C_POINTER(0, 0);
528+
INCR_C_POINTER(1, 0);
529+
INCR_C_POINTER(2, 0);
530+
INCR_C_POINTER(3, 0);
494531
}
532+
533+
UPDATE_B_POINTER(4);
534+
RESET_A_POINTER();
535+
UPDATE_C_POINTER(4);
495536
}
496537
for (; j < n2; j += 2) {
497538

539+
CREATE_C_POINTER(0, 0);
540+
CREATE_C_POINTER(1, 1);
541+
CREATE_B_POINTER(0, 0);
542+
CREATE_B_POINTER(1, 1);
543+
498544
BLASLONG i = 0;
499545
for (; i < v_m2; i += v_size2) {
500546

547+
CREATE_A_POINTER(0, 0);
548+
CREATE_A_POINTER(1, v_size);
549+
UPDATE_A_POINTER(v_size2);
550+
501551
BLASLONG k = 0;
502552
DECLARE_RESULT_VECTOR(0, 0);
503553
DECLARE_RESULT_VECTOR(0, 1);
@@ -538,9 +588,14 @@ CNAME(BLASLONG M,
538588
VECTOR_STORE(pg_true, 0, 1);
539589
VECTOR_STORE(pg_true, 1, 0);
540590
VECTOR_STORE(pg_true, 1, 1);
591+
INCR_C_POINTER(0, v_size2);
592+
INCR_C_POINTER(1, v_size2);
541593
}
542594
for (; i < v_m1; i += v_size) {
543595

596+
CREATE_A_POINTER(0, 0);
597+
UPDATE_A_POINTER(v_size);
598+
544599
BLASLONG k = 0;
545600
DECLARE_RESULT_VECTOR(0, 0);
546601
DECLARE_RESULT_VECTOR(0, 1);
@@ -568,9 +623,13 @@ CNAME(BLASLONG M,
568623
}
569624
VECTOR_STORE(pg_true, 0, 0);
570625
VECTOR_STORE(pg_true, 0, 1);
626+
INCR_C_POINTER(0, v_size);
627+
INCR_C_POINTER(1, v_size);
571628
}
572629
for (; i < M; i += v_size) {
573630
const svbool_t pg_tail = svwhilelt_b64((uint64_t)i, (uint64_t)(M));
631+
CREATE_A_POINTER(0, 0);
632+
UPDATE_A_POINTER(0);
574633

575634
BLASLONG k = 0;
576635
DECLARE_RESULT_VECTOR(0, 0);
@@ -599,13 +658,26 @@ CNAME(BLASLONG M,
599658
}
600659
VECTOR_STORE(pg_tail, 0, 0);
601660
VECTOR_STORE(pg_tail, 0, 1);
661+
INCR_C_POINTER(0, 0);
662+
INCR_C_POINTER(1, 0);
602663
}
664+
665+
UPDATE_B_POINTER(2);
666+
RESET_A_POINTER();
667+
UPDATE_C_POINTER(2);
603668
}
604669
for (; j < N; j++) {
605670

671+
CREATE_C_POINTER(0, 0);
672+
CREATE_B_POINTER(0, 0);
673+
606674
BLASLONG i = 0;
607675
for (; i < v_m2; i += v_size2) {
608676

677+
CREATE_A_POINTER(0, 0);
678+
CREATE_A_POINTER(1, v_size);
679+
UPDATE_A_POINTER(v_size2);
680+
609681
BLASLONG k = 0;
610682
DECLARE_RESULT_VECTOR(0, 0);
611683
DECLARE_RESULT_VECTOR(1, 0);
@@ -620,9 +692,13 @@ CNAME(BLASLONG M,
620692
}
621693
VECTOR_STORE(pg_true, 0, 0);
622694
VECTOR_STORE(pg_true, 1, 0);
695+
INCR_C_POINTER(0, v_size2);
623696
}
624697
for (; i < v_m1; i += v_size) {
625698

699+
CREATE_A_POINTER(0, 0);
700+
UPDATE_A_POINTER(v_size);
701+
626702
BLASLONG k = 0;
627703
DECLARE_RESULT_VECTOR(0, 0);
628704

@@ -633,9 +709,12 @@ CNAME(BLASLONG M,
633709
UPDATE_RESULT_VECTOR(pg_true, 0, 0, 0);
634710
}
635711
VECTOR_STORE(pg_true, 0, 0);
712+
INCR_C_POINTER(0, v_size);
636713
}
637714
for (; i < M; i += v_size) {
638715
const svbool_t pg_tail = svwhilelt_b64((uint64_t)i, (uint64_t)(M));
716+
CREATE_A_POINTER(0, 0);
717+
UPDATE_A_POINTER(0);
639718

640719
BLASLONG k = 0;
641720
DECLARE_RESULT_VECTOR(0, 0);
@@ -647,11 +726,16 @@ CNAME(BLASLONG M,
647726
UPDATE_RESULT_VECTOR(pg_tail, 0, 0, 0);
648727
}
649728
VECTOR_STORE(pg_tail, 0, 0);
729+
INCR_C_POINTER(0, 0);
650730
}
731+
732+
UPDATE_B_POINTER(1);
733+
RESET_A_POINTER();
734+
UPDATE_C_POINTER(1);
651735
}
652736

653737
if (pack_b)
654738
free(packed_b);
655739

656740
return 0;
657-
}
741+
}

0 commit comments

Comments
 (0)