Skip to content

Commit ead47b5

Browse files
ViralBShah, staticfloat
authored and committed
Special-case MKL threading setters/getters
This uses the `MKL_Domain_*` functions to get/set the thread count for the BLAS domain only, so that other MKL domains (FFT, PARDISO, etc.) are not affected by `lbt_set_num_threads()`. It also adds a test to show that this behavior is reasonable when MKL is loaded.
1 parent 5a08240 commit ead47b5

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

src/threading.c

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
#define MAX_THREADING_NAMES 32
44

5+
#ifndef max
6+
#define max(x, y) ((x) > (y) ? (x) : (y))
7+
#endif
8+
9+
// We need to ask MKL to get/set threads for only the BLAS domain, we therefore pass in
10+
// this constant to the relevant threading functions to limit our thread setting.
11+
#define MKL_DOMAIN_BLAS 1 /* From mkl_types.h */
12+
513
/*
614
* We provide a flexible thread getter/setter interface here; by calling `lbt_set_num_threads()`
715
* libblastrampoline will propagate the call through to its loaded libraries as long as the
@@ -10,15 +18,17 @@
1018
*/
1119
static char * getter_names[MAX_THREADING_NAMES] = {
1220
"openblas_get_num_threads",
13-
"MKL_Get_Max_Threads",
1421
"bli_thread_get_num_threads",
22+
// We special-case MKL in the lookup loop below
23+
//"MKL_Domain_Get_Max_Threads",
1524
NULL
1625
};
1726

1827
static char * setter_names[MAX_THREADING_NAMES] = {
1928
"openblas_set_num_threads",
20-
"MKL_Set_Num_Threads",
2129
"bli_thread_set_num_threads",
30+
// We special-case MKL in the lookup loop below
31+
//"MKL_Domain_Set_Num_Threads",
2232
NULL
2333
};
2434

@@ -62,9 +72,16 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
6272
int (*fptr)() = lookup_symbol(lib->handle, symbol_name);
6373
if (fptr != NULL) {
6474
int new_threads = fptr();
65-
max_threads = max_threads > new_threads ? max_threads : new_threads;
75+
max_threads = max(max_threads, new_threads);
6676
}
6777
}
78+
79+
// Special-case MKL, as we need to specifically ask for the "BLAS" domain
80+
int (*fptr)(int) = lookup_symbol(lib->handle, "MKL_Domain_Get_Max_Threads");
81+
if (fptr != NULL) {
82+
int new_threads = fptr(MKL_DOMAIN_BLAS);
83+
max_threads = max(max_threads, new_threads);
84+
}
6885
}
6986
return max_threads;
7087
}
@@ -76,15 +93,21 @@ LBT_DLLEXPORT int32_t lbt_get_num_threads() {
7693
*/
7794
LBT_DLLEXPORT void lbt_set_num_threads(int32_t nthreads) {
7895
const lbt_config_t * config = lbt_get_config();
96+
char symbol_name[MAX_SYMBOL_LEN];
7997
for (int lib_idx=0; config->loaded_libs[lib_idx] != NULL; ++lib_idx) {
8098
lbt_library_info_t * lib = config->loaded_libs[lib_idx];
8199
for (int symbol_idx=0; setter_names[symbol_idx] != NULL; ++symbol_idx) {
82-
char symbol_name[MAX_SYMBOL_LEN];
83100
build_symbol_name(symbol_name, setter_names[symbol_idx], lib->suffix);
84101
void (*fptr)(int) = lookup_symbol(lib->handle, symbol_name);
85102
if (fptr != NULL) {
86103
fptr(nthreads);
87104
}
88105
}
106+
107+
// Special-case MKL, as we need to specifically ask for the "BLAS" domain
108+
int (*fptr)(int, int) = lookup_symbol(lib->handle, "MKL_Domain_Set_Num_Threads");
109+
if (fptr != NULL) {
110+
fptr(nthreads, MKL_DOMAIN_BLAS);
111+
}
89112
}
90113
}

test/direct.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,4 +331,16 @@ if MKL_jll.is_available() && Sys.ARCH == :x86_64
331331
@test result[1] ≈ ComplexF32(1.47 + 3.83im)
332332
@test isempty(stacktraces)
333333
end
334+
335+
@testset "MKL threading domains" begin
336+
nthreads = lbt_get_num_threads(lbt_handle)
337+
if nthreads <= 1
338+
nthreads = 2
339+
else
340+
nthreads = div(nthreads, 2)
341+
end
342+
lbt_set_num_threads(lbt_handle, nthreads)
343+
@test ccall((:MKL_Domain_Get_Max_Threads, libmkl_rt), Cint, (Cint,), 1) == nthreads
344+
@test ccall((:MKL_Domain_Get_Max_Threads, libmkl_rt), Cint, (Cint,), 2) != nthreads
345+
end
334346
end

0 commit comments

Comments
 (0)