2424# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525
2626
27- from nose .tools import assert_equals , nottest
27+ import unittest
28+ from unittest import skip
2829import six
2930import mkl
3031
3132
32- class test_version_information ():
33+ class test_version_information (unittest . TestCase ):
3334 def test_get_version (self ):
3435 v = mkl .get_version ()
35- assert ( isinstance ( v , dict ) )
36- assert ('MajorVersion' in v )
37- assert ('MinorVersion' in v )
38- assert ('UpdateVersion' in v )
36+ self . assertIsInstance ( v , dict )
37+ self . assertIn ('MajorVersion' , v )
38+ self . assertIn ('MinorVersion' , v )
39+ self . assertIn ('UpdateVersion' , v )
3940
4041 def test_get_version_string (self ):
4142 v = mkl .get_version_string ()
42- assert ( isinstance ( v , six .string_types ) )
43- assert ('Math Kernel Library' in v )
43+ self . assertIsInstance ( v , six .string_types )
44+ self . assertIn ('Math Kernel Library' , v )
4445
4546
46- class test_threading_control ( ):
47+ class TestThreadingControl ( unittest . TestCase ):
4748 def test_set_num_threads (self ):
4849 saved = mkl .get_max_threads ()
4950 half_nt = int ( (1 + saved ) / 2 )
5051 mkl .set_num_threads (half_nt )
51- assert (mkl .get_max_threads () == half_nt )
52+ self . assertEqual (mkl .get_max_threads (), half_nt )
5253 mkl .set_num_threads (saved )
5354
5455 def test_domain_set_num_threads_blas (self ):
@@ -60,49 +61,49 @@ def test_domain_set_num_threads_blas(self):
6061 fft_nt = int ( (3 + 2 * saved_fft_nt )/ 4 )
6162 vml_nt = int ( (3 + 3 * saved_vml_nt )/ 4 )
6263 status = mkl .domain_set_num_threads (blas_nt , domain = 'blas' )
63- assert (status == 'success' )
64+ self . assertEqual (status , 'success' )
6465 status = mkl .domain_set_num_threads (fft_nt , domain = 'fft' )
65- assert (status == 'success' )
66+ self . assertEqual (status , 'success' )
6667 status = mkl .domain_set_num_threads (vml_nt , domain = 'vml' )
67- assert (status == 'success' )
68+ self . assertEqual (status , 'success' )
6869 # check
69- assert (mkl .domain_get_max_threads (domain = 'blas' ) == blas_nt )
70- assert (mkl .domain_get_max_threads (domain = 'fft' ) == fft_nt )
71- assert (mkl .domain_get_max_threads (domain = 'vml' ) == vml_nt )
70+ self . assertEqual (mkl .domain_get_max_threads (domain = 'blas' ), blas_nt )
71+ self . assertEqual (mkl .domain_get_max_threads (domain = 'fft' ), fft_nt )
72+ self . assertEqual (mkl .domain_get_max_threads (domain = 'vml' ), vml_nt )
7273 # restore
7374 status = mkl .domain_set_num_threads (saved_blas_nt , domain = 'blas' )
74- assert (status == 'success' )
75+ self . assertEqual (status , 'success' )
7576 status = mkl .domain_set_num_threads (saved_fft_nt , domain = 'fft' )
76- assert (status == 'success' )
77+ self . assertEqual (status , 'success' )
7778 status = mkl .domain_set_num_threads (saved_vml_nt , domain = 'vml' )
78- assert (status == 'success' )
79-
79+ self . assertEqual (status , 'success' )
80+
8081 def test_domain_set_num_threads_fft (self ):
8182 status = mkl .domain_set_num_threads (4 , domain = 'fft' )
82- assert (status == 'success' )
83+ self . assertEqual (status , 'success' )
8384
8485 def test_domain_set_num_threads_vml (self ):
8586 status = mkl .domain_set_num_threads (4 , domain = 'vml' )
86- assert (status == 'success' )
87+ self . assertEqual (status , 'success' )
8788
8889 def test_domain_set_num_threads_pardiso (self ):
8990 status = mkl .domain_set_num_threads (4 , domain = 'pardiso' )
90- assert (status == 'success' )
91+ self . assertEqual (status , 'success' )
9192
9293 def test_domain_set_num_threads_all (self ):
9394 status = mkl .domain_set_num_threads (4 , domain = 'all' )
94- assert (status == 'success' )
95+ self . assertEqual (status , 'success' )
9596
9697 def test_set_num_threads_local (self ):
9798 mkl .set_num_threads (1 )
9899 status = mkl .set_num_threads_local (2 )
99- assert (status == 'global_num_threads' )
100+ self . assertEqual (status , 'global_num_threads' )
100101 status = mkl .set_num_threads_local (4 )
101- assert (status == 2 )
102+ self . assertEqual (status , 2 )
102103 status = mkl .set_num_threads_local (0 )
103- assert (status == 4 )
104+ self . assertEqual (status , 4 )
104105 status = mkl .set_num_threads_local (8 )
105- assert (status == 'global_num_threads' )
106+ self . assertEqual (status , 'global_num_threads' )
106107
107108 def test_set_dynamic (self ):
108109 mkl .set_dynamic (True )
@@ -129,37 +130,37 @@ def test_get_dynamic(self):
129130 mkl .get_dynamic ()
130131
131132
132- class test_timing ():
133+ class test_timing (unittest . TestCase ):
133134 # https://software.intel.com/en-us/mkl-developer-reference-c-timing
134135 def test_second (self ):
135136 s1 = mkl .second ()
136137 s2 = mkl .second ()
137138 delta = s2 - s1
138- assert (delta >= 0 )
139+ self . assertGreaterEqual (delta , 0 )
139140
140141 def test_dsecnd (self ):
141142 d1 = mkl .dsecnd ()
142143 d2 = mkl .dsecnd ()
143144 delta = d2 - d1
144- assert (delta >= 0 )
145+ self . assertGreaterEqual (delta , 0 )
145146
146147 def test_get_cpu_clocks (self ):
147148 c1 = mkl .get_cpu_clocks ()
148149 c2 = mkl .get_cpu_clocks ()
149150 delta = c2 - c1
150- assert (delta >= 0 )
151+ self . assertGreaterEqual (delta , 0 )
151152
152153 def test_get_cpu_frequency (self ):
153- assert (mkl .get_cpu_frequency () > 0 )
154+ self . assertGreater (mkl .get_cpu_frequency (), 0 )
154155
155156 def test_get_max_cpu_frequency (self ):
156- assert (mkl .get_max_cpu_frequency () > 0 )
157+ self . assertGreater (mkl .get_max_cpu_frequency (), 0 )
157158
158159 def test_get_clocks_frequency (self ):
159- assert (mkl .get_clocks_frequency () > 0 )
160+ self . assertGreater (mkl .get_clocks_frequency (), 0 )
160161
161162
162- class test_memory_management ():
163+ class test_memory_management (unittest . TestCase ):
163164 def test_free_buffers (self ):
164165 mkl .free_buffers ()
165166
@@ -188,7 +189,7 @@ def test_set_memory_limit(self):
188189 mkl .set_memory_limit (128 )
189190
190191
191- class test_cnr_control ( ):
192+ class TestCNRControl ( unittest . TestCase ):
192193 def test_cbwr (self ):
193194 branches = [
194195 'off' ,
@@ -213,26 +214,28 @@ def test_cbwr(self):
213214 'avx512_e1,strict' ,
214215 ]
215216 for branch in branches :
216- yield self .check_cbwr , branch , 'branch'
217+ with self .subTest (branch = branch ):
218+ self .check_cbwr (branch , 'branch' )
217219 for branch in branches + strict :
218- yield self .check_cbwr , branch , 'all'
220+ with self .subTest (branch = branch ):
221+ self .check_cbwr (branch , 'all' )
219222
220223 def check_cbwr (self , branch , cnr_const ):
221224 status = mkl .cbwr_set (branch = branch )
222225 if status == 'success' :
223226 expected_value = 'branch_off' if branch == 'off' else branch
224227 actual_value = mkl .cbwr_get (cnr_const = cnr_const )
225- assert_equals (actual_value ,
226- expected_value ,
227- msg = "Round-trip failure for CNR branch '{}', CNR const '{}'" .format (branch , cnr_const ))
228+ self . assertEqual (actual_value ,
229+ expected_value ,
230+ "Round-trip failure for CNR branch '{}', CNR const '{}'" .format (branch , cnr_const ))
228231 elif status != 'err_unsupported_branch' :
229232 raise AssertionError (status )
230233
231234 def test_cbwr_get_auto_branch (self ):
232235 mkl .cbwr_get_auto_branch ()
233236
234237
235- class test_miscellaneous ():
238+ class test_miscellaneous (unittest . TestCase ):
236239 def test_enable_instructions_avx512_mic_e1 (self ):
237240 mkl .enable_instructions ('avx512_mic_e1' )
238241
@@ -266,7 +269,7 @@ def test_verbose_true(self):
266269 #def test_set_mpi_custom(self):
267270 # mkl.set_mpi('custom', 'custom_library_name')
268271
269- @nottest
272+ @skip
270273 def test_set_mpi_msmpi (self ):
271274 mkl .set_mpi ('msmpi' )
272275
@@ -277,7 +280,7 @@ def test_set_mpi_msmpi(self):
277280 # mkl.set_mpi('mpich2')
278281
279282
280- class test_vm_service_functions ():
283+ class test_vm_service_functions (unittest . TestCase ):
281284 def test_vml_set_get_mode_roundtrip (self ):
282285 saved = mkl .vml_get_mode ()
283286 mkl .vml_set_mode (* saved ) # should not raise errors
@@ -332,3 +335,6 @@ def test_vml_get_err_status(self):
332335
333336 def test_vml_clear_err_status (self ):
334337 mkl .vml_clear_err_status ()
338+
339+ if __name__ == '__main__' :
340+ unittest .main ()
0 commit comments