22
33#define MAX_THREADING_NAMES 32
44
5+ #ifndef max
6+ #define max (x , y ) ((x) > (y) ? (x) : (y))
7+ #endif
8+
9+ // We need to ask MKL to get/set threads for only the BLAS domain, we therefore pass in
10+ // this constant to the relevant threading functions to limit our thread setting.
11+ #define MKL_DOMAIN_BLAS 1 /* From mkl_types.h */
12+
513/*
614 * We provide a flexible thread getter/setter interface here; by calling `lbt_set_num_threads()`
715 * libblastrampoline will propagate the call through to its loaded libraries as long as the
1018 */
1119static char * getter_names [MAX_THREADING_NAMES ] = {
1220 "openblas_get_num_threads" ,
13- "MKL_Get_Max_Threads" ,
1421 "bli_thread_get_num_threads" ,
22+ // We special-case MKL in the lookup loop below
23+ //"MKL_Domain_Get_Max_Threads",
1524 NULL
1625};
1726
1827static char * setter_names [MAX_THREADING_NAMES ] = {
1928 "openblas_set_num_threads" ,
20- "MKL_Set_Num_Threads" ,
2129 "bli_thread_set_num_threads" ,
30+ // We special-case MKL in the lookup loop below
31+ //"MKL_Domain_Set_Num_Threads",
2232 NULL
2333};
2434
@@ -62,9 +72,16 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
6272 int (* fptr )() = lookup_symbol (lib -> handle , symbol_name );
6373 if (fptr != NULL ) {
6474 int new_threads = fptr ();
65- max_threads = max_threads > new_threads ? max_threads : new_threads ;
75+ max_threads = max ( max_threads , new_threads ) ;
6676 }
6777 }
78+
79+ // Special-case MKL, as we need to specifically ask for the "BLAS" domain
80+ int (* fptr )(int ) = lookup_symbol (lib -> handle , "MKL_Domain_Get_Max_Threads" );
81+ if (fptr != NULL ) {
82+ int new_threads = fptr (MKL_DOMAIN_BLAS );
83+ max_threads = max (max_threads , new_threads );
84+ }
6885 }
6986 return max_threads ;
7087}
@@ -76,15 +93,21 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
7693 */
7794LBT_DLLEXPORT void lbt_set_num_threads (int32_t nthreads ) {
7895 const lbt_config_t * config = lbt_get_config ();
96+ char symbol_name [MAX_SYMBOL_LEN ];
7997 for (int lib_idx = 0 ; config -> loaded_libs [lib_idx ] != NULL ; ++ lib_idx ) {
8098 lbt_library_info_t * lib = config -> loaded_libs [lib_idx ];
8199 for (int symbol_idx = 0 ; setter_names [symbol_idx ] != NULL ; ++ symbol_idx ) {
82- char symbol_name [MAX_SYMBOL_LEN ];
83100 build_symbol_name (symbol_name , setter_names [symbol_idx ], lib -> suffix );
84101 void (* fptr )(int ) = lookup_symbol (lib -> handle , symbol_name );
85102 if (fptr != NULL ) {
86103 fptr (nthreads );
87104 }
88105 }
106+
107+ // Special-case MKL, as we need to specifically ask for the "BLAS" domain
108+ int (* fptr )(int , int ) = lookup_symbol (lib -> handle , "MKL_Domain_Set_Num_Threads" );
109+ if (fptr != NULL ) {
110+ fptr (nthreads , MKL_DOMAIN_BLAS );
111+ }
89112 }
90113}
0 commit comments