Skip to content

Commit 6aaa107

Browse files
author
Tim Moon
committed
Reducing threads for multi-threaded GEMMs on small matrices.
1 parent 2ccd7f6 commit 6aaa107

File tree

1 file changed

+9
-56
lines changed

1 file changed

+9
-56
lines changed

driver/level3/level3_thread.c

Lines changed: 9 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -684,8 +684,6 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
684 684   BLASLONG m = args -> m;
685 685   BLASLONG n = args -> n;
686 686   BLASLONG nthreads = args -> nthreads;
687     - BLASLONG divN, divT;
688     - int mode;
689 687
690 688   if (nthreads == 1) {
691 689   GEMM_LOCAL(args, range_m, range_n, sa, sb, 0);
@@ -706,66 +704,21 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO
706 704   n = n_to - n_from;
707 705   }
708 706
709     - if ((m < nthreads * SWITCH_RATIO) || (n < nthreads * SWITCH_RATIO)) {
    707 + if ((m < 2 * SWITCH_RATIO) || (n < 2 * SWITCH_RATIO)) {
710 708   GEMM_LOCAL(args, range_m, range_n, sa, sb, 0);
711 709   return 0;
712 710   }
713 711
714     - divT = nthreads;
715     - divN = 1;
716     -
717     - #if 0
718     - while ((GEMM_P * divT > m * SWITCH_RATIO) && (divT > 1)) {
719     - do {
720     - divT --;
721     - divN = 1;
722     - while (divT * divN < nthreads) divN ++;
723     - } while ((divT * divN != nthreads) && (divT > 1));
    712 + if (m < nthreads * SWITCH_RATIO) {
    713 + nthreads = blas_quickdivide(m, SWITCH_RATIO);
724 714   }
725     - #endif
726     -
727     - // fprintf(stderr, "divN = %4ld divT = %4ld\n", divN, divT);
728     -
729     - args -> nthreads = divT;
730     -
731     - if (divN == 1){
732     -
733     - gemm_driver(args, range_m, range_n, sa, sb, 0);
734     - } else {
735     - #ifndef COMPLEX
736     - #ifdef XDOUBLE
737     - mode = BLAS_XDOUBLE | BLAS_REAL;
738     - #elif defined(DOUBLE)
739     - mode = BLAS_DOUBLE | BLAS_REAL;
740     - #else
741     - mode = BLAS_SINGLE | BLAS_REAL;
742     - #endif
743     - #else
744     - #ifdef XDOUBLE
745     - mode = BLAS_XDOUBLE | BLAS_COMPLEX;
746     - #elif defined(DOUBLE)
747     - mode = BLAS_DOUBLE | BLAS_COMPLEX;
748     - #else
749     - mode = BLAS_SINGLE | BLAS_COMPLEX;
750     - #endif
751     - #endif
752     -
753     - #if defined(TN) || defined(TT) || defined(TR) || defined(TC) || \
754     - defined(CN) || defined(CT) || defined(CR) || defined(CC)
755     - mode |= (BLAS_TRANSA_T);
756     - #endif
757     - #if defined(NT) || defined(TT) || defined(RT) || defined(CT) || \
758     - defined(NC) || defined(TC) || defined(RC) || defined(CC)
759     - mode |= (BLAS_TRANSB_T);
760     - #endif
761     -
762     - #ifdef OS_WINDOWS
763     - gemm_thread_n(mode, args, range_m, range_n, GEMM_LOCAL, sa, sb, divN);
764     - #else
765     - gemm_thread_n(mode, args, range_m, range_n, gemm_driver, sa, sb, divN);
766     - #endif
767     -
    715 + if (n < nthreads * SWITCH_RATIO) {
    716 + nthreads = blas_quickdivide(n, SWITCH_RATIO);
768 717   }
769 718
    719 + args -> nthreads = nthreads;
    720 +
    721 + gemm_driver(args, range_m, range_n, sa, sb, 0);
    722 +
770 723   return 0;
771 724   }

0 commit comments

Comments
 (0)