Skip to content

Commit e816fd6

Browse files
authored
Merge pull request #158 from JuliaLinearAlgebra/im/mklthreads
MKL support: Set the LAPACK domain thread number
2 parents 44ecc1f + 6bcde51 commit e816fd6

File tree

1 file changed

+45
-7
lines changed

1 file changed

+45
-7
lines changed

src/threading.c

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,30 @@
66
#define max(x, y) ((x) > (y) ? (x) : (y))
77
#endif
88

9-
// We need to ask MKL to get/set threads for only the BLAS domain, we therefore pass in
10-
// this constant to the relevant threading functions to limit our thread setting.
11-
#define MKL_DOMAIN_BLAS 1 /* From mkl_types.h */
9+
/* We need to ask MKL to get/set threads for only the BLAS & LAPACK domains, we
10+
* therefore pass in this constant to the relevant threading functions to limit
11+
* our thread setting.
12+
*
13+
* The LAPACK threading domain was added in 2025.2, so we need to ask for the MKL
14+
* version in order to properly handle that.
15+
*
16+
* These are from mkl_types.h
17+
*/
18+
#define MKL_DOMAIN_BLAS 1
19+
#define MKL_DOMAIN_LAPACK 5
20+
21+
typedef
22+
struct {
23+
int MajorVersion;
24+
int MinorVersion;
25+
int UpdateVersion;
26+
int PatchVersion;
27+
char * ProductStatus;
28+
char * Build;
29+
char * Processor;
30+
char * Platform;
31+
} MKLVersion;
32+
1233

1334
/*
1435
* We provide a flexible thread getter/setter interface here; by calling `lbt_set_num_threads()`
@@ -76,11 +97,27 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
7697
}
7798
}
7899

79-
// Special-case MKL, as we need to specifically ask for the "BLAS" domain
100+
// Special-case MKL, as we need to specifically ask for the BLAS & LAPACK domains
80101
int (*fptr)(int) = lookup_symbol(lib->handle, "MKL_Domain_Get_Max_Threads");
81102
if (fptr != NULL) {
82-
int new_threads = fptr(MKL_DOMAIN_BLAS);
83-
max_threads = max(max_threads, new_threads);
103+
// The BLAS domain (always available)
104+
int new_threads_blas = fptr(MKL_DOMAIN_BLAS);
105+
max_threads = max(max_threads, new_threads_blas);
106+
107+
// The LAPACK threading domain was only added in oneMKL 2025.2
108+
// Gate reading the threads based on this version, because before 2025.2 it
109+
// will return the default, which wouldn't be useful because we max everything.
110+
void (*fverptr)(MKLVersion*) = lookup_symbol(lib->handle, "mkl_get_version");
111+
if (fverptr != NULL) {
112+
MKLVersion ver;
113+
fverptr(&ver);
114+
115+
// MKL considers 2025.2 to be a major of 2025 and an update of 2 (not a minor)
116+
if(ver.MajorVersion >= 2025 && ver.UpdateVersion >= 2) {
117+
int new_threads_lapack = fptr(MKL_DOMAIN_LAPACK);
118+
max_threads = max(max_threads, new_threads_lapack);
119+
}
120+
}
84121
}
85122
}
86123
return max_threads;
@@ -104,10 +141,11 @@ LBT_DLLEXPORT void lbt_set_num_threads(int32_t nthreads) {
104141
}
105142
}
106143

107-
// Special-case MKL, as we need to specifically ask for the "BLAS" domain
144+
// Special-case MKL, as we need to specifically ask for the BLAS & LAPACK domains
108145
int (*fptr)(int, int) = lookup_symbol(lib->handle, "MKL_Domain_Set_Num_Threads");
109146
if (fptr != NULL) {
110147
fptr(nthreads, MKL_DOMAIN_BLAS);
148+
fptr(nthreads, MKL_DOMAIN_LAPACK);
111149
}
112150
}
113151
}

0 commit comments

Comments
 (0)