24
24
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
25
26
26
27
- from nose .tools import assert_equals , nottest
27
+ import unittest
28
+ from unittest import skip
28
29
import six
29
30
import mkl
30
31
31
32
32
- class test_version_information ():
33
+ class test_version_information (unittest . TestCase ):
33
34
def test_get_version (self ):
34
35
v = mkl .get_version ()
35
- assert ( isinstance ( v , dict ) )
36
- assert ('MajorVersion' in v )
37
- assert ('MinorVersion' in v )
38
- assert ('UpdateVersion' in v )
36
+ self . assertIsInstance ( v , dict )
37
+ self . assertIn ('MajorVersion' , v )
38
+ self . assertIn ('MinorVersion' , v )
39
+ self . assertIn ('UpdateVersion' , v )
39
40
40
41
def test_get_version_string (self ):
41
42
v = mkl .get_version_string ()
42
- assert ( isinstance ( v , six .string_types ) )
43
- assert ('Math Kernel Library' in v )
43
+ self . assertIsInstance ( v , six .string_types )
44
+ self . assertIn ('Math Kernel Library' , v )
44
45
45
46
46
- class test_threading_control ( ):
47
+ class TestThreadingControl ( unittest . TestCase ):
47
48
def test_set_num_threads (self ):
48
49
saved = mkl .get_max_threads ()
49
50
half_nt = int ( (1 + saved ) / 2 )
50
51
mkl .set_num_threads (half_nt )
51
- assert (mkl .get_max_threads () == half_nt )
52
+ self . assertEqual (mkl .get_max_threads (), half_nt )
52
53
mkl .set_num_threads (saved )
53
54
54
55
def test_domain_set_num_threads_blas (self ):
@@ -60,49 +61,49 @@ def test_domain_set_num_threads_blas(self):
60
61
fft_nt = int ( (3 + 2 * saved_fft_nt )/ 4 )
61
62
vml_nt = int ( (3 + 3 * saved_vml_nt )/ 4 )
62
63
status = mkl .domain_set_num_threads (blas_nt , domain = 'blas' )
63
- assert (status == 'success' )
64
+ self . assertEqual (status , 'success' )
64
65
status = mkl .domain_set_num_threads (fft_nt , domain = 'fft' )
65
- assert (status == 'success' )
66
+ self . assertEqual (status , 'success' )
66
67
status = mkl .domain_set_num_threads (vml_nt , domain = 'vml' )
67
- assert (status == 'success' )
68
+ self . assertEqual (status , 'success' )
68
69
# check
69
- assert (mkl .domain_get_max_threads (domain = 'blas' ) == blas_nt )
70
- assert (mkl .domain_get_max_threads (domain = 'fft' ) == fft_nt )
71
- assert (mkl .domain_get_max_threads (domain = 'vml' ) == vml_nt )
70
+ self . assertEqual (mkl .domain_get_max_threads (domain = 'blas' ), blas_nt )
71
+ self . assertEqual (mkl .domain_get_max_threads (domain = 'fft' ), fft_nt )
72
+ self . assertEqual (mkl .domain_get_max_threads (domain = 'vml' ), vml_nt )
72
73
# restore
73
74
status = mkl .domain_set_num_threads (saved_blas_nt , domain = 'blas' )
74
- assert (status == 'success' )
75
+ self . assertEqual (status , 'success' )
75
76
status = mkl .domain_set_num_threads (saved_fft_nt , domain = 'fft' )
76
- assert (status == 'success' )
77
+ self . assertEqual (status , 'success' )
77
78
status = mkl .domain_set_num_threads (saved_vml_nt , domain = 'vml' )
78
- assert (status == 'success' )
79
-
79
+ self . assertEqual (status , 'success' )
80
+
80
81
def test_domain_set_num_threads_fft (self ):
81
82
status = mkl .domain_set_num_threads (4 , domain = 'fft' )
82
- assert (status == 'success' )
83
+ self . assertEqual (status , 'success' )
83
84
84
85
def test_domain_set_num_threads_vml (self ):
85
86
status = mkl .domain_set_num_threads (4 , domain = 'vml' )
86
- assert (status == 'success' )
87
+ self . assertEqual (status , 'success' )
87
88
88
89
def test_domain_set_num_threads_pardiso (self ):
89
90
status = mkl .domain_set_num_threads (4 , domain = 'pardiso' )
90
- assert (status == 'success' )
91
+ self . assertEqual (status , 'success' )
91
92
92
93
def test_domain_set_num_threads_all (self ):
93
94
status = mkl .domain_set_num_threads (4 , domain = 'all' )
94
- assert (status == 'success' )
95
+ self . assertEqual (status , 'success' )
95
96
96
97
def test_set_num_threads_local (self ):
97
98
mkl .set_num_threads (1 )
98
99
status = mkl .set_num_threads_local (2 )
99
- assert (status == 'global_num_threads' )
100
+ self . assertEqual (status , 'global_num_threads' )
100
101
status = mkl .set_num_threads_local (4 )
101
- assert (status == 2 )
102
+ self . assertEqual (status , 2 )
102
103
status = mkl .set_num_threads_local (0 )
103
- assert (status == 4 )
104
+ self . assertEqual (status , 4 )
104
105
status = mkl .set_num_threads_local (8 )
105
- assert (status == 'global_num_threads' )
106
+ self . assertEqual (status , 'global_num_threads' )
106
107
107
108
def test_set_dynamic (self ):
108
109
mkl .set_dynamic (True )
@@ -129,37 +130,37 @@ def test_get_dynamic(self):
129
130
mkl .get_dynamic ()
130
131
131
132
132
- class test_timing ():
133
+ class test_timing (unittest . TestCase ):
133
134
# https://software.intel.com/en-us/mkl-developer-reference-c-timing
134
135
def test_second (self ):
135
136
s1 = mkl .second ()
136
137
s2 = mkl .second ()
137
138
delta = s2 - s1
138
- assert (delta >= 0 )
139
+ self . assertGreaterEqual (delta , 0 )
139
140
140
141
def test_dsecnd (self ):
141
142
d1 = mkl .dsecnd ()
142
143
d2 = mkl .dsecnd ()
143
144
delta = d2 - d1
144
- assert (delta >= 0 )
145
+ self . assertGreaterEqual (delta , 0 )
145
146
146
147
def test_get_cpu_clocks (self ):
147
148
c1 = mkl .get_cpu_clocks ()
148
149
c2 = mkl .get_cpu_clocks ()
149
150
delta = c2 - c1
150
- assert (delta >= 0 )
151
+ self . assertGreaterEqual (delta , 0 )
151
152
152
153
def test_get_cpu_frequency (self ):
153
- assert (mkl .get_cpu_frequency () > 0 )
154
+ self . assertGreater (mkl .get_cpu_frequency (), 0 )
154
155
155
156
def test_get_max_cpu_frequency (self ):
156
- assert (mkl .get_max_cpu_frequency () > 0 )
157
+ self . assertGreater (mkl .get_max_cpu_frequency (), 0 )
157
158
158
159
def test_get_clocks_frequency (self ):
159
- assert (mkl .get_clocks_frequency () > 0 )
160
+ self . assertGreater (mkl .get_clocks_frequency (), 0 )
160
161
161
162
162
- class test_memory_management ():
163
+ class test_memory_management (unittest . TestCase ):
163
164
def test_free_buffers (self ):
164
165
mkl .free_buffers ()
165
166
@@ -188,7 +189,7 @@ def test_set_memory_limit(self):
188
189
mkl .set_memory_limit (128 )
189
190
190
191
191
- class test_cnr_control ( ):
192
+ class TestCNRControl ( unittest . TestCase ):
192
193
def test_cbwr (self ):
193
194
branches = [
194
195
'off' ,
@@ -213,26 +214,28 @@ def test_cbwr(self):
213
214
'avx512_e1,strict' ,
214
215
]
215
216
for branch in branches :
216
- yield self .check_cbwr , branch , 'branch'
217
+ with self .subTest (branch = branch ):
218
+ self .check_cbwr (branch , 'branch' )
217
219
for branch in branches + strict :
218
- yield self .check_cbwr , branch , 'all'
220
+ with self .subTest (branch = branch ):
221
+ self .check_cbwr (branch , 'all' )
219
222
220
223
def check_cbwr (self , branch , cnr_const ):
221
224
status = mkl .cbwr_set (branch = branch )
222
225
if status == 'success' :
223
226
expected_value = 'branch_off' if branch == 'off' else branch
224
227
actual_value = mkl .cbwr_get (cnr_const = cnr_const )
225
- assert_equals (actual_value ,
226
- expected_value ,
227
- msg = "Round-trip failure for CNR branch '{}', CNR const '{}'" .format (branch , cnr_const ))
228
+ self . assertEqual (actual_value ,
229
+ expected_value ,
230
+ "Round-trip failure for CNR branch '{}', CNR const '{}'" .format (branch , cnr_const ))
228
231
elif status != 'err_unsupported_branch' :
229
232
raise AssertionError (status )
230
233
231
234
def test_cbwr_get_auto_branch (self ):
232
235
mkl .cbwr_get_auto_branch ()
233
236
234
237
235
- class test_miscellaneous ():
238
+ class test_miscellaneous (unittest . TestCase ):
236
239
def test_enable_instructions_avx512_mic_e1 (self ):
237
240
mkl .enable_instructions ('avx512_mic_e1' )
238
241
@@ -266,7 +269,7 @@ def test_verbose_true(self):
266
269
#def test_set_mpi_custom(self):
267
270
# mkl.set_mpi('custom', 'custom_library_name')
268
271
269
- @nottest
272
+ @skip
270
273
def test_set_mpi_msmpi (self ):
271
274
mkl .set_mpi ('msmpi' )
272
275
@@ -277,7 +280,7 @@ def test_set_mpi_msmpi(self):
277
280
# mkl.set_mpi('mpich2')
278
281
279
282
280
- class test_vm_service_functions ():
283
+ class test_vm_service_functions (unittest . TestCase ):
281
284
def test_vml_set_get_mode_roundtrip (self ):
282
285
saved = mkl .vml_get_mode ()
283
286
mkl .vml_set_mode (* saved ) # should not raise errors
@@ -332,3 +335,6 @@ def test_vml_get_err_status(self):
332
335
333
336
def test_vml_clear_err_status (self ):
334
337
mkl .vml_clear_err_status ()
338
+
339
+ if __name__ == '__main__' :
340
+ unittest .main ()
0 commit comments