66#define max (x , y ) ((x) > (y) ? (x) : (y))
77#endif
88
9- // We need to ask MKL to get/set threads for only the BLAS domain, we therefore pass in
10- // this constant to the relevant threading functions to limit our thread setting.
11- #define MKL_DOMAIN_BLAS 1 /* From mkl_types.h */
9+ /* We need to ask MKL to get/set threads for only the BLAS & LAPACK domains, we
10+ * therefore pass in this constant to the relevant threading functions to limit
11+ * our thread setting.
12+ *
13+ * The LAPACK threading domain was added in 2025.2, so we need to ask for the MKL
14+ * version in order to properly handle that.
15+ *
16+ * These are from mkl_types.h
17+ */
18+ #define MKL_DOMAIN_BLAS 1
19+ #define MKL_DOMAIN_LAPACK 5
20+
21+ typedef
22+ struct {
23+ int MajorVersion ;
24+ int MinorVersion ;
25+ int UpdateVersion ;
26+ int PatchVersion ;
27+ char * ProductStatus ;
28+ char * Build ;
29+ char * Processor ;
30+ char * Platform ;
31+ } MKLVersion ;
32+
1233
1334/*
1435 * We provide a flexible thread getter/setter interface here; by calling `lbt_set_num_threads()`
@@ -76,11 +97,27 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
7697 }
7798 }
7899
79- // Special-case MKL, as we need to specifically ask for the " BLAS" domain
100+ // Special-case MKL, as we need to specifically ask for the BLAS & LAPACK domains
80101 int (* fptr )(int ) = lookup_symbol (lib -> handle , "MKL_Domain_Get_Max_Threads" );
81102 if (fptr != NULL ) {
82- int new_threads = fptr (MKL_DOMAIN_BLAS );
83- max_threads = max (max_threads , new_threads );
103+ // The BLAS domain (always available)
104+ int new_threads_blas = fptr (MKL_DOMAIN_BLAS );
105+ max_threads = max (max_threads , new_threads_blas );
106+
107+ // The LAPACK threading domain was only added in oneMKL 2025.2
108+ // Gate reading the threads based on this version, because before 2025.2 it
109+ // will return the default, which wouldn't be useful because we max everything.
110+ void (* fverptr )(MKLVersion * ) = lookup_symbol (lib -> handle , "mkl_get_version" );
111+ if (fverptr != NULL ) {
112+ MKLVersion ver ;
113+ fverptr (& ver );
114+
115+ // MKL considers 2025.2 to be a major of 2025 and an update of 2 (not a minor)
116+ if (ver .MajorVersion >= 2025 && ver .UpdateVersion >= 2 ) {
117+ int new_threads_lapack = fptr (MKL_DOMAIN_LAPACK );
118+ max_threads = max (max_threads , new_threads_lapack );
119+ }
120+ }
84121 }
85122 }
86123 return max_threads ;
@@ -104,10 +141,11 @@ LBT_DLLEXPORT void lbt_set_num_threads(int32_t nthreads) {
104141 }
105142 }
106143
107- // Special-case MKL, as we need to specifically ask for the " BLAS" domain
144+ // Special-case MKL, as we need to specifically ask for the BLAS & LAPACK domains
108145 int (* fptr )(int , int ) = lookup_symbol (lib -> handle , "MKL_Domain_Set_Num_Threads" );
109146 if (fptr != NULL ) {
110147 fptr (nthreads , MKL_DOMAIN_BLAS );
148+ fptr (nthreads , MKL_DOMAIN_LAPACK );
111149 }
112150 }
113151}
0 commit comments