@@ -48,25 +48,29 @@ class test_threading_control():
48
48
# https://software.intel.com/en-us/mkl-developer-reference-c-threading-control
49
49
def test_set_num_threads (self ):
50
50
saved = mkl .get_max_threads ()
51
- mkl .set_num_threads (8 )
52
- assert (mkl .get_max_threads () == 8 )
51
+ half_nt = int ( 0.5 + saved / 2 )
52
+ mkl .set_num_threads (half_nt )
53
+ assert (mkl .get_max_threads () == half_nt )
53
54
mkl .set_num_threads (saved )
54
55
55
56
def test_domain_set_num_threads_blas (self ):
56
57
saved_blas_nt = mkl .domain_get_max_threads (domain = 'blas' )
57
58
saved_fft_nt = mkl .domain_get_max_threads (domain = 'fft' )
58
59
saved_vml_nt = mkl .domain_get_max_threads (domain = 'vml' )
59
60
# set
60
- status = mkl .domain_set_num_threads (4 , domain = 'blas' )
61
+ blas_nt = int ( (3 + saved_blas_nt )/ 4 )
62
+ fft_nt = int ( (3 + 2 * saved_fft_nt )/ 4 )
63
+ vml_nt = int ( (3 + 3 * saved_vml_nt )/ 4 )
64
+ status = mkl .domain_set_num_threads (blas_nt , domain = 'blas' )
61
65
assert (status == 'success' )
62
- status = mkl .domain_set_num_threads (5 , domain = 'fft' )
66
+ status = mkl .domain_set_num_threads (fft_nt , domain = 'fft' )
63
67
assert (status == 'success' )
64
- status = mkl .domain_set_num_threads (6 , domain = 'vml' )
68
+ status = mkl .domain_set_num_threads (vml_nt , domain = 'vml' )
65
69
assert (status == 'success' )
66
70
# check
67
- assert (mkl .domain_get_max_threads (domain = 'blas' ) == 4 )
68
- assert (mkl .domain_get_max_threads (domain = 'fft' ) == 5 )
69
- assert (mkl .domain_get_max_threads (domain = 'vml' ) == 6 )
71
+ assert (mkl .domain_get_max_threads (domain = 'blas' ) == blas_nt )
72
+ assert (mkl .domain_get_max_threads (domain = 'fft' ) == fft_nt )
73
+ assert (mkl .domain_get_max_threads (domain = 'vml' ) == vml_nt )
70
74
# restore
71
75
status = mkl .domain_set_num_threads (saved_blas_nt , domain = 'blas' )
72
76
assert (status == 'success' )
0 commit comments