Skip to content

Commit d23680b

Browse files
authored
Merge pull request #5407 from nakagawa-fj/feature/gemm_divide_rate_for_neoversev1
Multi-thread Performance Improvement of GEMM on NeoverseV1 with DIVIDE_RATE=1
2 parents b4cc4be + 7e29f11 commit d23680b

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

driver/level3/gemm.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@
6363
#define DIVIDE_RATE GEMM_DIVIDE_RATE
6464
#endif
6565

66+
#ifdef GEMM_DIVIDE_LIMIT
67+
#define DIVIDE_LIMIT GEMM_DIVIDE_LIMIT
68+
#endif
69+
6670
#ifdef THREADED_LEVEL3
6771
#include "level3_thread.c"
6872
#else

driver/level3/level3_thread.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
246246

247247
BLASLONG nthreads_m;
248248
BLASLONG mypos_m, mypos_n;
249+
BLASLONG divide_rate = DIVIDE_RATE;
249250

250251
BLASLONG is, js, ls, bufferside, jjs;
251252
BLASLONG min_i, min_l, div_n, min_jj;
@@ -280,6 +281,11 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
280281
alpha = (FLOAT *)args -> alpha;
281282
beta = (FLOAT *)args -> beta;
282283

284+
/* Disable divide_rate when N of all threads are less than to DIVIDE_LIMIT */
285+
#ifdef DIVIDE_LIMIT
286+
if (N < DIVIDE_LIMIT) divide_rate = 1;
287+
#endif
288+
283289
/* Initialize 2D CPU distribution */
284290
nthreads_m = args -> nthreads;
285291
if (range_m) {
@@ -321,9 +327,9 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
321327
) return 0;
322328

323329
/* Initialize workspace for local region of B */
324-
div_n = (n_to - n_from + DIVIDE_RATE - 1) / DIVIDE_RATE;
330+
div_n = (n_to - n_from + divide_rate - 1) / divide_rate;
325331
buffer[0] = sb;
326-
for (i = 1; i < DIVIDE_RATE; i++) {
332+
for (i = 1; i < divide_rate; i++) {
327333
buffer[i] = buffer[i - 1] + GEMM_Q * ((div_n + GEMM_UNROLL_N - 1)/GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE;
328334
}
329335

@@ -365,7 +371,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
365371
STOP_RPCC(copy_A);
366372

367373
/* Copy local region of B into workspace and apply kernel */
368-
div_n = (n_to - n_from + DIVIDE_RATE - 1) / DIVIDE_RATE;
374+
div_n = (n_to - n_from + divide_rate - 1) / divide_rate;
369375
for (js = n_from, bufferside = 0; js < n_to; js += div_n, bufferside ++) {
370376

371377
/* Make sure if no one is using workspace */
@@ -434,7 +440,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
434440
if (current >= (mypos_n + 1) * nthreads_m) current = mypos_n * nthreads_m;
435441

436442
/* Split other region of B into parts */
437-
div_n = (range_n[current + 1] - range_n[current] + DIVIDE_RATE - 1) / DIVIDE_RATE;
443+
div_n = (range_n[current + 1] - range_n[current] + divide_rate - 1) / divide_rate;
438444
for (js = range_n[current], bufferside = 0; js < range_n[current + 1]; js += div_n, bufferside ++) {
439445
if (current != mypos) {
440446

@@ -485,7 +491,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
485491
do {
486492

487493
/* Split region of B into parts and apply kernel */
488-
div_n = (range_n[current + 1] - range_n[current] + DIVIDE_RATE - 1) / DIVIDE_RATE;
494+
div_n = (range_n[current + 1] - range_n[current] + divide_rate - 1) / divide_rate;
489495
for (js = range_n[current], bufferside = 0; js < range_n[current + 1]; js += div_n, bufferside ++) {
490496

491497
/* Apply kernel with local region of A and part of region of B */
@@ -520,7 +526,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
520526
/* Wait until all other threads are done with local region of B */
521527
START_RPCC();
522528
for (i = 0; i < args -> nthreads; i++) {
523-
for (js = 0; js < DIVIDE_RATE; js++) {
529+
for (js = 0; js < divide_rate; js++) {
524530
while (job[mypos].working[i][CACHE_LINE_SIZE * js] ) {YIELDING;};
525531
}
526532
}

param.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3585,6 +3585,8 @@ is a big desktop or server with abundant cache rather than a phone or embedded d
35853585

35863586
#elif defined(NEOVERSEV1) // 256-bit SVE
35873587

3588+
#define GEMM_DIVIDE_LIMIT 3
3589+
35883590
#if defined(XDOUBLE) || defined(DOUBLE)
35893591
#define SWITCH_RATIO 8
35903592
#define GEMM_PREFERED_SIZE 4

0 commit comments

Comments
 (0)