Skip to content

Commit 8143077

Browse files
Updated Threading Control and CBWR functions to Python style.
1 parent 245fcd8 commit 8143077

File tree

2 files changed

+220
-95
lines changed

2 files changed

+220
-95
lines changed

mkl-service/_mkl_service.pyx

Lines changed: 212 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,112 @@
11
cimport _mkl_service as mkl
2-
3-
4-
# MKL_INT64 mkl_peak_mem_usage(int mode)
5-
# In
6-
MKL_PEAK_MEM_ENABLE = mkl.MKL_PEAK_MEM_ENABLE
7-
MKL_PEAK_MEM_DISABLE = mkl.MKL_PEAK_MEM_DISABLE
8-
MKL_PEAK_MEM = mkl.MKL_PEAK_MEM
9-
MKL_PEAK_MEM_RESET = mkl.MKL_PEAK_MEM_RESET
10-
11-
# int mkl_set_memory_limit(int mem_type, size_t limit)
12-
# In
13-
MKL_MEM_MCDRAM = mkl.MKL_MEM_MCDRAM
14-
15-
# int mkl_cbwr_set(int settings)
16-
# In
17-
MKL_CBWR_AUTO = mkl.MKL_CBWR_AUTO
18-
MKL_CBWR_COMPATIBLE = mkl.MKL_CBWR_COMPATIBLE
19-
MKL_CBWR_SSE2 = mkl.MKL_CBWR_SSE2
20-
MKL_CBWR_SSE3 = mkl.MKL_CBWR_SSE3
21-
MKL_CBWR_SSSE3 = mkl.MKL_CBWR_SSSE3
22-
MKL_CBWR_SSE4_1 = mkl.MKL_CBWR_SSE4_1
23-
MKL_CBWR_SSE4_2 = mkl.MKL_CBWR_SSE4_2
24-
MKL_CBWR_AVX = mkl.MKL_CBWR_AVX
25-
MKL_CBWR_AVX2 = mkl.MKL_CBWR_AVX2
26-
MKL_CBWR_AVX512_MIC = mkl.MKL_CBWR_AVX512_MIC
27-
MKL_CBWR_AVX512 = mkl.MKL_CBWR_AVX512_MIC
28-
# Out
29-
MKL_CBWR_SUCCESS = mkl.MKL_CBWR_SUCCESS
30-
MKL_CBWR_ERR_INVALID_INPUT = mkl.MKL_CBWR_ERR_INVALID_INPUT
31-
MKL_CBWR_ERR_UNSUPPORTED_BRANCH = mkl.MKL_CBWR_ERR_UNSUPPORTED_BRANCH
32-
MKL_CBWR_ERR_MODE_CHANGE_FAILURE = mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE
33-
34-
# int mkl_cbwr_get(int option)
35-
# In
36-
MKL_CBWR_BRANCH = mkl.MKL_CBWR_BRANCH
37-
MKL_CBWR_ALL = mkl.MKL_CBWR_ALL
38-
39-
# int mkl_enable_instructions(int isa)
40-
# In
41-
MKL_ENABLE_AVX512 = mkl.MKL_ENABLE_AVX512
42-
MKL_ENABLE_AVX512_MIC = mkl.MKL_ENABLE_AVX512_MIC
43-
MKL_ENABLE_AVX2 = mkl.MKL_ENABLE_AVX2
44-
MKL_ENABLE_AVX = mkl.MKL_ENABLE_AVX
45-
MKL_ENABLE_SSE4_2 = mkl.MKL_ENABLE_SSE4_2
46-
47-
# unsigned int vmlSetMode(unsigned int mode)
48-
# In
49-
VML_HA = mkl.VML_HA
50-
VML_LA = mkl.VML_LA
51-
VML_EP = mkl.VML_EP
52-
VML_FTZDAZ_ON = mkl.VML_FTZDAZ_ON
53-
VML_FTZDAZ_OFF = mkl.VML_FTZDAZ_OFF
54-
VML_ERRMODE_IGNORE = mkl.VML_ERRMODE_IGNORE
55-
VML_ERRMODE_ERRNO = mkl.VML_ERRMODE_ERRNO
56-
VML_ERRMODE_STDERR = mkl.VML_ERRMODE_STDERR
57-
VML_ERRMODE_EXCEPT = mkl.VML_ERRMODE_EXCEPT
58-
VML_ERRMODE_CALLBACK = mkl.VML_ERRMODE_CALLBACK
59-
VML_ERRMODE_DEFAULT = mkl.VML_ERRMODE_DEFAULT
60-
61-
# int vmlSetErrStatus(const MKL_INT status)
62-
# In
63-
VML_STATUS_OK = mkl.VML_STATUS_OK
64-
VML_STATUS_ACCURACYWARNING = mkl.VML_STATUS_ACCURACYWARNING
65-
VML_STATUS_BADSIZE = mkl.VML_STATUS_BADSIZE
66-
VML_STATUS_BADMEM = mkl.VML_STATUS_BADMEM
67-
VML_STATUS_ERRDOM = mkl.VML_STATUS_ERRDOM
68-
VML_STATUS_SING = mkl.VML_STATUS_SING
69-
VML_STATUS_OVERFLOW = mkl.VML_STATUS_OVERFLOW
70-
VML_STATUS_UNDERFLOW = mkl.VML_STATUS_UNDERFLOW
2+
from enum import IntEnum
3+
4+
5+
class enums(IntEnum):
6+
# MKL Function Domains
7+
MKL_DOMAIN_BLAS = mkl.MKL_DOMAIN_BLAS
8+
MKL_DOMAIN_FFT = mkl.MKL_DOMAIN_FFT
9+
MKL_DOMAIN_VML = mkl.MKL_DOMAIN_VML
10+
MKL_DOMAIN_PARDISO = mkl.MKL_DOMAIN_PARDISO
11+
MKL_DOMAIN_ALL = mkl.MKL_DOMAIN_ALL
12+
13+
# MKL_INT64 mkl_peak_mem_usage(int mode)
14+
# In
15+
MKL_PEAK_MEM_ENABLE = mkl.MKL_PEAK_MEM_ENABLE
16+
MKL_PEAK_MEM_DISABLE = mkl.MKL_PEAK_MEM_DISABLE
17+
MKL_PEAK_MEM = mkl.MKL_PEAK_MEM
18+
MKL_PEAK_MEM_RESET = mkl.MKL_PEAK_MEM_RESET
19+
20+
# int mkl_set_memory_limit(int mem_type, size_t limit)
21+
# In
22+
MKL_MEM_MCDRAM = mkl.MKL_MEM_MCDRAM
23+
24+
# CNR Control Constants
25+
MKL_CBWR_AUTO = mkl.MKL_CBWR_AUTO
26+
MKL_CBWR_COMPATIBLE = mkl.MKL_CBWR_COMPATIBLE
27+
MKL_CBWR_SSE2 = mkl.MKL_CBWR_SSE2
28+
MKL_CBWR_SSE3 = mkl.MKL_CBWR_SSE3
29+
MKL_CBWR_SSSE3 = mkl.MKL_CBWR_SSSE3
30+
MKL_CBWR_SSE4_1 = mkl.MKL_CBWR_SSE4_1
31+
MKL_CBWR_SSE4_2 = mkl.MKL_CBWR_SSE4_2
32+
MKL_CBWR_AVX = mkl.MKL_CBWR_AVX
33+
MKL_CBWR_AVX2 = mkl.MKL_CBWR_AVX2
34+
MKL_CBWR_AVX512_MIC = mkl.MKL_CBWR_AVX512_MIC
35+
MKL_CBWR_AVX512 = mkl.MKL_CBWR_AVX512_MIC
36+
MKL_CBWR_BRANCH = mkl.MKL_CBWR_BRANCH
37+
MKL_CBWR_ALL = mkl.MKL_CBWR_ALL
38+
MKL_CBWR_SUCCESS = mkl.MKL_CBWR_SUCCESS
39+
MKL_CBWR_ERR_INVALID_INPUT = mkl.MKL_CBWR_ERR_INVALID_INPUT
40+
MKL_CBWR_ERR_UNSUPPORTED_BRANCH = mkl.MKL_CBWR_ERR_UNSUPPORTED_BRANCH
41+
MKL_CBWR_ERR_MODE_CHANGE_FAILURE = mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE
42+
43+
# int mkl_enable_instructions(int isa)
44+
# In
45+
MKL_ENABLE_AVX512 = mkl.MKL_ENABLE_AVX512
46+
MKL_ENABLE_AVX512_MIC = mkl.MKL_ENABLE_AVX512_MIC
47+
MKL_ENABLE_AVX2 = mkl.MKL_ENABLE_AVX2
48+
MKL_ENABLE_AVX = mkl.MKL_ENABLE_AVX
49+
MKL_ENABLE_SSE4_2 = mkl.MKL_ENABLE_SSE4_2
50+
51+
# unsigned int vmlSetMode(unsigned int mode)
52+
# In
53+
VML_HA = mkl.VML_HA
54+
VML_LA = mkl.VML_LA
55+
VML_EP = mkl.VML_EP
56+
VML_FTZDAZ_ON = mkl.VML_FTZDAZ_ON
57+
VML_FTZDAZ_OFF = mkl.VML_FTZDAZ_OFF
58+
VML_ERRMODE_IGNORE = mkl.VML_ERRMODE_IGNORE
59+
VML_ERRMODE_ERRNO = mkl.VML_ERRMODE_ERRNO
60+
VML_ERRMODE_STDERR = mkl.VML_ERRMODE_STDERR
61+
VML_ERRMODE_EXCEPT = mkl.VML_ERRMODE_EXCEPT
62+
VML_ERRMODE_CALLBACK = mkl.VML_ERRMODE_CALLBACK
63+
VML_ERRMODE_DEFAULT = mkl.VML_ERRMODE_DEFAULT
64+
65+
# int vmlSetErrStatus(const MKL_INT status)
66+
# In
67+
VML_STATUS_OK = mkl.VML_STATUS_OK
68+
VML_STATUS_ACCURACYWARNING = mkl.VML_STATUS_ACCURACYWARNING
69+
VML_STATUS_BADSIZE = mkl.VML_STATUS_BADSIZE
70+
VML_STATUS_BADMEM = mkl.VML_STATUS_BADMEM
71+
VML_STATUS_ERRDOM = mkl.VML_STATUS_ERRDOM
72+
VML_STATUS_SING = mkl.VML_STATUS_SING
73+
VML_STATUS_OVERFLOW = mkl.VML_STATUS_OVERFLOW
74+
VML_STATUS_UNDERFLOW = mkl.VML_STATUS_UNDERFLOW
75+
76+
# MKL Function Domains
77+
__mkl_domain_enums = {'blas': mkl.MKL_DOMAIN_BLAS,
78+
'fft': mkl.MKL_DOMAIN_FFT,
79+
'vml': mkl.MKL_DOMAIN_VML,
80+
'pardiso': mkl.MKL_DOMAIN_PARDISO,
81+
'all': mkl.MKL_DOMAIN_ALL}
82+
83+
# CNR Control Constants
84+
__mkl_cbwr_set_in_enums = {'auto': mkl.MKL_CBWR_AUTO,
85+
'compatible': mkl.MKL_CBWR_COMPATIBLE,
86+
'sse2': mkl.MKL_CBWR_SSE2,
87+
'sse3': mkl.MKL_CBWR_SSE3,
88+
'ssse3': mkl.MKL_CBWR_SSSE3,
89+
'sse4_1': mkl.MKL_CBWR_SSE4_1,
90+
'sse4_2': mkl.MKL_CBWR_SSE4_2,
91+
'avx': mkl.MKL_CBWR_AVX,
92+
'avx2': mkl.MKL_CBWR_AVX2,
93+
'avx512_mic': mkl.MKL_CBWR_AVX512_MIC,
94+
'avx512': mkl.MKL_CBWR_AVX512}
95+
96+
__mkl_cbwr_set_out_enums = {mkl.MKL_CBWR_SUCCESS: 'success',
97+
mkl.MKL_CBWR_ERR_INVALID_INPUT: 'err_invalid_input',
98+
mkl.MKL_CBWR_ERR_UNSUPPORTED_BRANCH: 'err_unsupported_branch',
99+
mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE: 'err_mode_change_failure'}
100+
101+
__mkl_cbwr_get_in_enums = {'branch': mkl.MKL_CBWR_BRANCH,
102+
'all': mkl.MKL_CBWR_ALL}
103+
104+
__mkl_cbwr_get_out_enums = {mkl.MKL_CBWR_SUCCESS: 'success',
105+
mkl.MKL_CBWR_ERR_INVALID_INPUT: 'err_invalid_input'}
106+
__mkl_cbwr_get_out_enums.update({value: key for key, value in __mkl_cbwr_set_in_enums.items()})
107+
108+
__mkl_cbwr_get_auto_branch_out_enums = {}
109+
__mkl_cbwr_get_auto_branch_out_enums.update({value: key for key, value in __mkl_cbwr_set_in_enums.items()})
71110

72111

73112
'''
@@ -98,35 +137,92 @@ def mkl_get_version_string():
98137
int mkl_domain_get_max_threads(int domain)
99138
int mkl_get_dynamic()
100139
'''
101-
def mkl_set_num_threads(nt):
102-
mkl.mkl_set_num_threads(nt)
140+
def mkl_set_num_threads(num_threads):
141+
assert(type(num_threads) is int)
142+
assert(num_threads > 0)
143+
144+
prev_num_threads = mkl_get_max_threads()
145+
assert(type(prev_num_threads) is int)
146+
assert(prev_num_threads > 0)
147+
148+
mkl.mkl_set_num_threads(num_threads)
149+
150+
return prev_num_threads
151+
103152

153+
def mkl_domain_set_num_threads(num_threads, domain='all'):
154+
assert(type(num_threads) is int)
155+
assert(num_threads >= 0)
156+
domain_type = type(domain)
157+
if domain_type is str:
158+
assert(domain in __mkl_domain_enums.keys())
159+
domain = __mkl_domain_enums[domain]
160+
else:
161+
assert((domain_type is int) and (domain in __mkl_domain_enums.values()))
104162

105-
def mkl_domain_set_num_threads(nth, domain):
106-
return mkl.mkl_domain_set_num_threads(nth, domain)
163+
status = mkl.mkl_domain_set_num_threads(num_threads, domain)
164+
assert((status == 0) or (status == 1))
107165

166+
if (status == 1):
167+
status = 'success'
168+
else:
169+
status = 'error'
108170

109-
def mkl_set_num_threads_local(nth):
110-
return mkl.mkl_set_num_threads_local(nth)
171+
return status
172+
173+
174+
def mkl_set_num_threads_local(num_threads):
175+
assert(type(num_threads) is int)
176+
assert(num_threads >= 0)
177+
status = mkl.mkl_set_num_threads_local(num_threads)
178+
assert(status >= 0)
179+
180+
if (status == 0):
181+
status = 'global_num_threads'
182+
183+
return status
184+
185+
def mkl_set_dynamic(enable):
186+
assert(type(enable) is bool)
187+
if enable:
188+
enable = 1
189+
else:
190+
enable = 0
111191

192+
mkl.mkl_set_dynamic(enable)
112193

113-
def mkl_set_dynamic(flag):
114-
mkl.mkl_set_dynamic(flag)
194+
return mkl_get_max_threads()
115195

116196

117197
def mkl_get_max_threads():
118-
return mkl.mkl_get_max_threads()
198+
num_threads = mkl.mkl_get_max_threads()
199+
assert(type(num_threads) is int)
200+
assert(num_threads >= 1)
119201

202+
return num_threads
120203

121-
def mkl_domain_get_max_threads(domain):
122-
return mkl.mkl_domain_get_max_threads(domain)
204+
205+
def mkl_domain_get_max_threads(domain='all'):
206+
domain_type = type(domain)
207+
if domain_type is str:
208+
assert(domain in __mkl_domain_enums.keys())
209+
domain = __mkl_domain_enums[domain]
210+
else:
211+
assert((domain_type is int) and (domain in __mkl_domain_enums.values()))
212+
213+
num_threads = mkl.mkl_domain_get_max_threads(domain)
214+
assert(type(num_threads) is int)
215+
assert(num_threads >= 1)
216+
217+
return num_threads
123218

124219

125220
def mkl_get_dynamic():
221+
dynamic_enabled = mkl.mkl_get_dynamic()
222+
assert((dynamic_enabled == 0) or (dynamic_enabled == 1))
126223
return mkl.mkl_get_dynamic()
127224

128225

129-
130226
'''
131227
# Timing
132228
float second()
@@ -200,19 +296,47 @@ def mkl_set_memory_limit(mem_type, limit):
200296

201297

202298
'''
203-
#Conditional Numerical Reproducibility
299+
# Conditional Numerical Reproducibility
204300
int mkl_cbwr_set(int settings)
205-
int mkl_cbwr_get()
301+
int mkl_cbwr_get(int option)
206302
int mkl_cbwr_get_auto_branch()
207303
'''
208-
def mkl_cbwr_set(settings):
209-
return mkl.mkl_cbwr_set(settings)
304+
def mkl_cbwr_set(branch=''):
305+
branch_type = type(branch)
306+
if branch_type is str:
307+
assert(branch in __mkl_cbwr_set_in_enums.keys())
308+
branch = __mkl_cbwr_set_in_enums[branch]
309+
else:
310+
assert((branch_type is int) and (branch in __mkl_cbwr_set_in_enums.values()))
311+
312+
status = mkl.mkl_cbwr_set(branch)
313+
assert(status in __mkl_cbwr_set_out_enums.keys())
314+
315+
return __mkl_cbwr_set_out_enums[status]
316+
317+
318+
def mkl_cbwr_get(cnr_const=''):
319+
cnr_const_type = type(cnr_const)
320+
if cnr_const_type is str:
321+
assert(cnr_const in __mkl_cbwr_get_in_enums)
322+
cnr_const = __mkl_cbwr_get_in_enums[cnr_const]
323+
else:
324+
assert(issubclass(cnr_const_type, IntEnum))
325+
assert(type(cnr_const.value) is int)
326+
assert(cnr_const.value in __mkl_cbwr_get_in_enums.values())
327+
cnr_const = cnr_const.value
328+
329+
status = mkl.mkl_cbwr_get(cnr_const)
330+
assert(status in __mkl_cbwr_get_out_enums)
331+
332+
return __mkl_cbwr_get_out_enums[status]
210333

211-
def mkl_cbwr_get(option):
212-
return mkl.mkl_cbwr_get(option)
213334

214335
def mkl_cbwr_get_auto_branch():
215-
return mkl.mkl_cbwr_get_auto_branch()
336+
status = mkl.mkl_cbwr_get_auto_branch()
337+
assert(status in __mkl_cbwr_get_auto_branch_out_enums)
338+
339+
return __mkl_cbwr_get_auto_branch_out_enums[status]
216340

217341

218342
'''

mkl-service/tests/test_mkl_service.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def test_mkl_set_num_threads(self):
1616
mkl_service.mkl_set_num_threads(1)
1717

1818
def test_mkl_domain_set_num_threads(self):
19-
mkl_service.mkl_domain_set_num_threads(1, 1)
19+
mkl_service.mkl_domain_set_num_threads(1, domain='fft')
2020

2121
def test_mkl_set_num_threads_local(self):
2222
mkl_service.mkl_set_num_threads_local(1)
2323

2424
def test_mkl_set_dynamic(self):
25-
mkl_service.mkl_set_dynamic(0)
25+
mkl_service.mkl_set_dynamic(True)
2626

2727
def test_mkl_get_max_threads(self):
2828
mkl_service.mkl_get_max_threads()
@@ -88,10 +88,11 @@ def test_mkl_set_memory_limit(self):
8888
class test_conditional_numerical_reproducibility_control:
8989
# https://software.intel.com/en-us/mkl-developer-reference-c-conditional-numerical-reproducibility-control
9090
def test_mkl_cbwr_set(self):
91-
mkl_service.mkl_cbwr_set(mkl_service.MKL_CBWR_AUTO)
91+
#mkl_service.mkl_cbwr_set(mkl_service.MKL_CBWR_AUTO)
92+
mkl_service.mkl_cbwr_set(branch='auto')
9293

9394
def test_mkl_cbwr_get(self):
94-
mkl_service.mkl_cbwr_get(mkl_service.MKL_CBWR_ALL)
95+
mkl_service.mkl_cbwr_get(cnr_const=mkl_service.enums.MKL_CBWR_ALL)
9596

9697
def test_mkl_cbwr_get_auto_branch(self):
9798
mkl_service.mkl_cbwr_get_auto_branch()
@@ -100,7 +101,7 @@ def test_mkl_cbwr_get_auto_branch(self):
100101
class test_miscellaneous():
101102
# https://software.intel.com/en-us/mkl-developer-reference-c-miscellaneous
102103
def test_mkl_enable_instructions(self):
103-
mkl_service.mkl_enable_instructions(mkl_service.MKL_ENABLE_AVX)
104+
mkl_service.mkl_enable_instructions(mkl_service.enums.MKL_ENABLE_AVX)
104105

105106
def test_mkl_set_env_mode(self):
106107
mkl_service.mkl_set_env_mode(0)
@@ -114,13 +115,13 @@ def test_mkl_set_mpi(self):
114115
class test_vm_service_functions():
115116
# https://software.intel.com/en-us/mkl-developer-reference-c-vm-service-functions
116117
def test_vmlSetMode(self):
117-
mkl_service.vmlSetMode(mkl_service.VML_HA)
118+
mkl_service.vmlSetMode(mkl_service.enums.VML_HA)
118119

119120
def test_vmlGetMode(self):
120121
mkl_service.vmlGetMode()
121122

122123
def test_vmlSetErrStatus(self):
123-
mkl_service.vmlSetErrStatus(mkl_service.VML_STATUS_OK)
124+
mkl_service.vmlSetErrStatus(mkl_service.enums.VML_STATUS_OK)
124125

125126
def test_vmlGetErrStatus(self):
126127
mkl_service.vmlGetErrStatus()

0 commit comments

Comments
 (0)