|
1 | 1 | cimport _mkl_service as mkl
|
2 |
| - |
3 |
| - |
4 |
| -# MKL_INT64 mkl_peak_mem_usage(int mode) |
5 |
| -# In |
6 |
| -MKL_PEAK_MEM_ENABLE = mkl.MKL_PEAK_MEM_ENABLE |
7 |
| -MKL_PEAK_MEM_DISABLE = mkl.MKL_PEAK_MEM_DISABLE |
8 |
| -MKL_PEAK_MEM = mkl.MKL_PEAK_MEM |
9 |
| -MKL_PEAK_MEM_RESET = mkl.MKL_PEAK_MEM_RESET |
10 |
| - |
11 |
| -# int mkl_set_memory_limit(int mem_type, size_t limit) |
12 |
| -# In |
13 |
| -MKL_MEM_MCDRAM = mkl.MKL_MEM_MCDRAM |
14 |
| - |
15 |
| -# int mkl_cbwr_set(int settings) |
16 |
| -# In |
17 |
| -MKL_CBWR_AUTO = mkl.MKL_CBWR_AUTO |
18 |
| -MKL_CBWR_COMPATIBLE = mkl.MKL_CBWR_COMPATIBLE |
19 |
| -MKL_CBWR_SSE2 = mkl.MKL_CBWR_SSE2 |
20 |
| -MKL_CBWR_SSE3 = mkl.MKL_CBWR_SSE3 |
21 |
| -MKL_CBWR_SSSE3 = mkl.MKL_CBWR_SSSE3 |
22 |
| -MKL_CBWR_SSE4_1 = mkl.MKL_CBWR_SSE4_1 |
23 |
| -MKL_CBWR_SSE4_2 = mkl.MKL_CBWR_SSE4_2 |
24 |
| -MKL_CBWR_AVX = mkl.MKL_CBWR_AVX |
25 |
| -MKL_CBWR_AVX2 = mkl.MKL_CBWR_AVX2 |
26 |
| -MKL_CBWR_AVX512_MIC = mkl.MKL_CBWR_AVX512_MIC |
27 |
| -MKL_CBWR_AVX512 = mkl.MKL_CBWR_AVX512_MIC |
28 |
| -# Out |
29 |
| -MKL_CBWR_SUCCESS = mkl.MKL_CBWR_SUCCESS |
30 |
| -MKL_CBWR_ERR_INVALID_INPUT = mkl.MKL_CBWR_ERR_INVALID_INPUT |
31 |
| -MKL_CBWR_ERR_UNSUPPORTED_BRANCH = mkl.MKL_CBWR_ERR_UNSUPPORTED_BRANCH |
32 |
| -MKL_CBWR_ERR_MODE_CHANGE_FAILURE = mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE |
33 |
| - |
34 |
| -# int mkl_cbwr_get(int option) |
35 |
| -# In |
36 |
| -MKL_CBWR_BRANCH = mkl.MKL_CBWR_BRANCH |
37 |
| -MKL_CBWR_ALL = mkl.MKL_CBWR_ALL |
38 |
| - |
39 |
| -# int mkl_enable_instructions(int isa) |
40 |
| -# In |
41 |
| -MKL_ENABLE_AVX512 = mkl.MKL_ENABLE_AVX512 |
42 |
| -MKL_ENABLE_AVX512_MIC = mkl.MKL_ENABLE_AVX512_MIC |
43 |
| -MKL_ENABLE_AVX2 = mkl.MKL_ENABLE_AVX2 |
44 |
| -MKL_ENABLE_AVX = mkl.MKL_ENABLE_AVX |
45 |
| -MKL_ENABLE_SSE4_2 = mkl.MKL_ENABLE_SSE4_2 |
46 |
| - |
47 |
| -# unsigned int vmlSetMode(unsigned int mode) |
48 |
| -# In |
49 |
| -VML_HA = mkl.VML_HA |
50 |
| -VML_LA = mkl.VML_LA |
51 |
| -VML_EP = mkl.VML_EP |
52 |
| -VML_FTZDAZ_ON = mkl.VML_FTZDAZ_ON |
53 |
| -VML_FTZDAZ_OFF = mkl.VML_FTZDAZ_OFF |
54 |
| -VML_ERRMODE_IGNORE = mkl.VML_ERRMODE_IGNORE |
55 |
| -VML_ERRMODE_ERRNO = mkl.VML_ERRMODE_ERRNO |
56 |
| -VML_ERRMODE_STDERR = mkl.VML_ERRMODE_STDERR |
57 |
| -VML_ERRMODE_EXCEPT = mkl.VML_ERRMODE_EXCEPT |
58 |
| -VML_ERRMODE_CALLBACK = mkl.VML_ERRMODE_CALLBACK |
59 |
| -VML_ERRMODE_DEFAULT = mkl.VML_ERRMODE_DEFAULT |
60 |
| - |
61 |
| -# int vmlSetErrStatus(const MKL_INT status) |
62 |
| -# In |
63 |
| -VML_STATUS_OK = mkl.VML_STATUS_OK |
64 |
| -VML_STATUS_ACCURACYWARNING = mkl.VML_STATUS_ACCURACYWARNING |
65 |
| -VML_STATUS_BADSIZE = mkl.VML_STATUS_BADSIZE |
66 |
| -VML_STATUS_BADMEM = mkl.VML_STATUS_BADMEM |
67 |
| -VML_STATUS_ERRDOM = mkl.VML_STATUS_ERRDOM |
68 |
| -VML_STATUS_SING = mkl.VML_STATUS_SING |
69 |
| -VML_STATUS_OVERFLOW = mkl.VML_STATUS_OVERFLOW |
70 |
| -VML_STATUS_UNDERFLOW = mkl.VML_STATUS_UNDERFLOW |
| 2 | +from enum import IntEnum |
| 3 | + |
| 4 | + |
| 5 | +class enums(IntEnum): |
| 6 | + # MKL Function Domains |
| 7 | + MKL_DOMAIN_BLAS = mkl.MKL_DOMAIN_BLAS |
| 8 | + MKL_DOMAIN_FFT = mkl.MKL_DOMAIN_FFT |
| 9 | + MKL_DOMAIN_VML = mkl.MKL_DOMAIN_VML |
| 10 | + MKL_DOMAIN_PARDISO = mkl.MKL_DOMAIN_PARDISO |
| 11 | + MKL_DOMAIN_ALL = mkl.MKL_DOMAIN_ALL |
| 12 | + |
| 13 | + # MKL_INT64 mkl_peak_mem_usage(int mode) |
| 14 | + # In |
| 15 | + MKL_PEAK_MEM_ENABLE = mkl.MKL_PEAK_MEM_ENABLE |
| 16 | + MKL_PEAK_MEM_DISABLE = mkl.MKL_PEAK_MEM_DISABLE |
| 17 | + MKL_PEAK_MEM = mkl.MKL_PEAK_MEM |
| 18 | + MKL_PEAK_MEM_RESET = mkl.MKL_PEAK_MEM_RESET |
| 19 | + |
| 20 | + # int mkl_set_memory_limit(int mem_type, size_t limit) |
| 21 | + # In |
| 22 | + MKL_MEM_MCDRAM = mkl.MKL_MEM_MCDRAM |
| 23 | + |
| 24 | + # CNR Control Constants |
| 25 | + MKL_CBWR_AUTO = mkl.MKL_CBWR_AUTO |
| 26 | + MKL_CBWR_COMPATIBLE = mkl.MKL_CBWR_COMPATIBLE |
| 27 | + MKL_CBWR_SSE2 = mkl.MKL_CBWR_SSE2 |
| 28 | + MKL_CBWR_SSE3 = mkl.MKL_CBWR_SSE3 |
| 29 | + MKL_CBWR_SSSE3 = mkl.MKL_CBWR_SSSE3 |
| 30 | + MKL_CBWR_SSE4_1 = mkl.MKL_CBWR_SSE4_1 |
| 31 | + MKL_CBWR_SSE4_2 = mkl.MKL_CBWR_SSE4_2 |
| 32 | + MKL_CBWR_AVX = mkl.MKL_CBWR_AVX |
| 33 | + MKL_CBWR_AVX2 = mkl.MKL_CBWR_AVX2 |
| 34 | + MKL_CBWR_AVX512_MIC = mkl.MKL_CBWR_AVX512_MIC |
| 35 | + MKL_CBWR_AVX512 = mkl.MKL_CBWR_AVX512_MIC |
| 36 | + MKL_CBWR_BRANCH = mkl.MKL_CBWR_BRANCH |
| 37 | + MKL_CBWR_ALL = mkl.MKL_CBWR_ALL |
| 38 | + MKL_CBWR_SUCCESS = mkl.MKL_CBWR_SUCCESS |
| 39 | + MKL_CBWR_ERR_INVALID_INPUT = mkl.MKL_CBWR_ERR_INVALID_INPUT |
| 40 | + MKL_CBWR_ERR_UNSUPPORTED_BRANCH = mkl.MKL_CBWR_ERR_UNSUPPORTED_BRANCH |
| 41 | + MKL_CBWR_ERR_MODE_CHANGE_FAILURE = mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE |
| 42 | + |
| 43 | + # int mkl_enable_instructions(int isa) |
| 44 | + # In |
| 45 | + MKL_ENABLE_AVX512 = mkl.MKL_ENABLE_AVX512 |
| 46 | + MKL_ENABLE_AVX512_MIC = mkl.MKL_ENABLE_AVX512_MIC |
| 47 | + MKL_ENABLE_AVX2 = mkl.MKL_ENABLE_AVX2 |
| 48 | + MKL_ENABLE_AVX = mkl.MKL_ENABLE_AVX |
| 49 | + MKL_ENABLE_SSE4_2 = mkl.MKL_ENABLE_SSE4_2 |
| 50 | + |
| 51 | + # unsigned int vmlSetMode(unsigned int mode) |
| 52 | + # In |
| 53 | + VML_HA = mkl.VML_HA |
| 54 | + VML_LA = mkl.VML_LA |
| 55 | + VML_EP = mkl.VML_EP |
| 56 | + VML_FTZDAZ_ON = mkl.VML_FTZDAZ_ON |
| 57 | + VML_FTZDAZ_OFF = mkl.VML_FTZDAZ_OFF |
| 58 | + VML_ERRMODE_IGNORE = mkl.VML_ERRMODE_IGNORE |
| 59 | + VML_ERRMODE_ERRNO = mkl.VML_ERRMODE_ERRNO |
| 60 | + VML_ERRMODE_STDERR = mkl.VML_ERRMODE_STDERR |
| 61 | + VML_ERRMODE_EXCEPT = mkl.VML_ERRMODE_EXCEPT |
| 62 | + VML_ERRMODE_CALLBACK = mkl.VML_ERRMODE_CALLBACK |
| 63 | + VML_ERRMODE_DEFAULT = mkl.VML_ERRMODE_DEFAULT |
| 64 | + |
| 65 | + # int vmlSetErrStatus(const MKL_INT status) |
| 66 | + # In |
| 67 | + VML_STATUS_OK = mkl.VML_STATUS_OK |
| 68 | + VML_STATUS_ACCURACYWARNING = mkl.VML_STATUS_ACCURACYWARNING |
| 69 | + VML_STATUS_BADSIZE = mkl.VML_STATUS_BADSIZE |
| 70 | + VML_STATUS_BADMEM = mkl.VML_STATUS_BADMEM |
| 71 | + VML_STATUS_ERRDOM = mkl.VML_STATUS_ERRDOM |
| 72 | + VML_STATUS_SING = mkl.VML_STATUS_SING |
| 73 | + VML_STATUS_OVERFLOW = mkl.VML_STATUS_OVERFLOW |
| 74 | + VML_STATUS_UNDERFLOW = mkl.VML_STATUS_UNDERFLOW |
| 75 | + |
| 76 | +# MKL Function Domains |
| 77 | +__mkl_domain_enums = {'blas': mkl.MKL_DOMAIN_BLAS, |
| 78 | + 'fft': mkl.MKL_DOMAIN_FFT, |
| 79 | + 'vml': mkl.MKL_DOMAIN_VML, |
| 80 | + 'pardiso': mkl.MKL_DOMAIN_PARDISO, |
| 81 | + 'all': mkl.MKL_DOMAIN_ALL} |
| 82 | + |
| 83 | +# CNR Control Constants |
| 84 | +__mkl_cbwr_set_in_enums = {'auto': mkl.MKL_CBWR_AUTO, |
| 85 | + 'compatible': mkl.MKL_CBWR_COMPATIBLE, |
| 86 | + 'sse2': mkl.MKL_CBWR_SSE2, |
| 87 | + 'sse3': mkl.MKL_CBWR_SSE3, |
| 88 | + 'ssse3': mkl.MKL_CBWR_SSSE3, |
| 89 | + 'sse4_1': mkl.MKL_CBWR_SSE4_1, |
| 90 | + 'sse4_2': mkl.MKL_CBWR_SSE4_2, |
| 91 | + 'avx': mkl.MKL_CBWR_AVX, |
| 92 | + 'avx2': mkl.MKL_CBWR_AVX2, |
| 93 | + 'avx512_mic': mkl.MKL_CBWR_AVX512_MIC, |
| 94 | + 'avx512': mkl.MKL_CBWR_AVX512} |
| 95 | + |
| 96 | +__mkl_cbwr_set_out_enums = {mkl.MKL_CBWR_SUCCESS: 'success', |
| 97 | + mkl.MKL_CBWR_ERR_INVALID_INPUT: 'err_invalid_input', |
| 98 | + mkl.MKL_CBWR_ERR_UNSUPPORTED_BRANCH: 'err_unsupported_branch', |
| 99 | + mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE: 'err_mode_change_failure'} |
| 100 | + |
| 101 | +__mkl_cbwr_get_in_enums = {'branch': mkl.MKL_CBWR_BRANCH, |
| 102 | + 'all': mkl.MKL_CBWR_ALL} |
| 103 | + |
| 104 | +__mkl_cbwr_get_out_enums = {mkl.MKL_CBWR_SUCCESS: 'success', |
| 105 | + mkl.MKL_CBWR_ERR_INVALID_INPUT: 'err_invalid_input'} |
| 106 | +__mkl_cbwr_get_out_enums.update({value: key for key, value in __mkl_cbwr_set_in_enums.items()}) |
| 107 | + |
| 108 | +__mkl_cbwr_get_auto_branch_out_enums = {} |
| 109 | +__mkl_cbwr_get_auto_branch_out_enums.update({value: key for key, value in __mkl_cbwr_set_in_enums.items()}) |
71 | 110 |
|
72 | 111 |
|
73 | 112 | '''
|
@@ -98,35 +137,92 @@ def mkl_get_version_string():
|
98 | 137 | int mkl_domain_get_max_threads(int domain)
|
99 | 138 | int mkl_get_dynamic()
|
100 | 139 | '''
|
101 |
| -def mkl_set_num_threads(nt): |
102 |
| - mkl.mkl_set_num_threads(nt) |
| 140 | +def mkl_set_num_threads(num_threads): |
| 141 | + assert(type(num_threads) is int) |
| 142 | + assert(num_threads > 0) |
| 143 | + |
| 144 | + prev_num_threads = mkl_get_max_threads() |
| 145 | + assert(type(prev_num_threads) is int) |
| 146 | + assert(prev_num_threads > 0) |
| 147 | + |
| 148 | + mkl.mkl_set_num_threads(num_threads) |
| 149 | + |
| 150 | + return prev_num_threads |
| 151 | + |
103 | 152 |
|
| 153 | +def mkl_domain_set_num_threads(num_threads, domain='all'): |
| 154 | + assert(type(num_threads) is int) |
| 155 | + assert(num_threads >= 0) |
| 156 | + domain_type = type(domain) |
| 157 | + if domain_type is str: |
| 158 | + assert(domain in __mkl_domain_enums.keys()) |
| 159 | + domain = __mkl_domain_enums[domain] |
| 160 | + else: |
| 161 | + assert((domain_type is int) and (domain in __mkl_domain_enums.values())) |
104 | 162 |
|
105 |
| -def mkl_domain_set_num_threads(nth, domain): |
106 |
| - return mkl.mkl_domain_set_num_threads(nth, domain) |
| 163 | + status = mkl.mkl_domain_set_num_threads(num_threads, domain) |
| 164 | + assert((status == 0) or (status == 1)) |
107 | 165 |
|
| 166 | + if (status == 1): |
| 167 | + status = 'success' |
| 168 | + else: |
| 169 | + status = 'error' |
108 | 170 |
|
109 |
| -def mkl_set_num_threads_local(nth): |
110 |
| - return mkl.mkl_set_num_threads_local(nth) |
| 171 | + return status |
| 172 | + |
| 173 | + |
| 174 | +def mkl_set_num_threads_local(num_threads): |
| 175 | + assert(type(num_threads) is int) |
| 176 | + assert(num_threads >= 0) |
| 177 | + status = mkl.mkl_set_num_threads_local(num_threads) |
| 178 | + assert(status >= 0) |
| 179 | + |
| 180 | + if (status == 0): |
| 181 | + status = 'global_num_threads' |
| 182 | + |
| 183 | + return status |
| 184 | + |
| 185 | +def mkl_set_dynamic(enable): |
| 186 | + assert(type(enable) is bool) |
| 187 | + if enable: |
| 188 | + enable = 1 |
| 189 | + else: |
| 190 | + enable = 0 |
111 | 191 |
|
| 192 | + mkl.mkl_set_dynamic(enable) |
112 | 193 |
|
113 |
| -def mkl_set_dynamic(flag): |
114 |
| - mkl.mkl_set_dynamic(flag) |
| 194 | + return mkl_get_max_threads() |
115 | 195 |
|
116 | 196 |
|
117 | 197 | def mkl_get_max_threads():
|
118 |
| - return mkl.mkl_get_max_threads() |
| 198 | + num_threads = mkl.mkl_get_max_threads() |
| 199 | + assert(type(num_threads) is int) |
| 200 | + assert(num_threads >= 1) |
119 | 201 |
|
| 202 | + return num_threads |
120 | 203 |
|
121 |
| -def mkl_domain_get_max_threads(domain): |
122 |
| - return mkl.mkl_domain_get_max_threads(domain) |
| 204 | + |
| 205 | +def mkl_domain_get_max_threads(domain='all'): |
| 206 | + domain_type = type(domain) |
| 207 | + if domain_type is str: |
| 208 | + assert(domain in __mkl_domain_enums.keys()) |
| 209 | + domain = __mkl_domain_enums[domain] |
| 210 | + else: |
| 211 | + assert((domain_type is int) and (domain in __mkl_domain_enums.values())) |
| 212 | + |
| 213 | + num_threads = mkl.mkl_domain_get_max_threads(domain) |
| 214 | + assert(type(num_threads) is int) |
| 215 | + assert(num_threads >= 1) |
| 216 | + |
| 217 | + return num_threads |
123 | 218 |
|
124 | 219 |
|
125 | 220 | def mkl_get_dynamic():
|
| 221 | + dynamic_enabled = mkl.mkl_get_dynamic() |
| 222 | + assert((dynamic_enabled == 0) or (dynamic_enabled == 1)) |
126 | 223 | return mkl.mkl_get_dynamic()
|
127 | 224 |
|
128 | 225 |
|
129 |
| - |
130 | 226 | '''
|
131 | 227 | # Timing
|
132 | 228 | float second()
|
@@ -200,19 +296,47 @@ def mkl_set_memory_limit(mem_type, limit):
|
200 | 296 |
|
201 | 297 |
|
202 | 298 | '''
|
203 |
| - #Conditional Numerical Reproducibility |
| 299 | + # Conditional Numerical Reproducibility |
204 | 300 | int mkl_cbwr_set(int settings)
|
205 |
| - int mkl_cbwr_get() |
| 301 | + int mkl_cbwr_get(int option) |
206 | 302 | int mkl_cbwr_get_auto_branch()
|
207 | 303 | '''
|
208 |
| -def mkl_cbwr_set(settings): |
209 |
| - return mkl.mkl_cbwr_set(settings) |
| 304 | +def mkl_cbwr_set(branch=''): |
| 305 | + branch_type = type(branch) |
| 306 | + if branch_type is str: |
| 307 | + assert(branch in __mkl_cbwr_set_in_enums.keys()) |
| 308 | + branch = __mkl_cbwr_set_in_enums[branch] |
| 309 | + else: |
| 310 | + assert((branch_type is int) and (branch in __mkl_cbwr_set_in_enums.values())) |
| 311 | + |
| 312 | + status = mkl.mkl_cbwr_set(branch) |
| 313 | + assert(status in __mkl_cbwr_set_out_enums.keys()) |
| 314 | + |
| 315 | + return __mkl_cbwr_set_out_enums[status] |
| 316 | + |
| 317 | + |
| 318 | +def mkl_cbwr_get(cnr_const=''): |
| 319 | + cnr_const_type = type(cnr_const) |
| 320 | + if cnr_const_type is str: |
| 321 | + assert(cnr_const in __mkl_cbwr_get_in_enums) |
| 322 | + cnr_const = __mkl_cbwr_get_in_enums[cnr_const] |
| 323 | + else: |
| 324 | + assert(issubclass(cnr_const_type, IntEnum)) |
| 325 | + assert(type(cnr_const.value) is int) |
| 326 | + assert(cnr_const.value in __mkl_cbwr_get_in_enums.values()) |
| 327 | + cnr_const = cnr_const.value |
| 328 | + |
| 329 | + status = mkl.mkl_cbwr_get(cnr_const) |
| 330 | + assert(status in __mkl_cbwr_get_out_enums) |
| 331 | + |
| 332 | + return __mkl_cbwr_get_out_enums[status] |
210 | 333 |
|
211 |
| -def mkl_cbwr_get(option): |
212 |
| - return mkl.mkl_cbwr_get(option) |
213 | 334 |
|
214 | 335 | def mkl_cbwr_get_auto_branch():
|
215 |
| - return mkl.mkl_cbwr_get_auto_branch() |
| 336 | + status = mkl.mkl_cbwr_get_auto_branch() |
| 337 | + assert(status in __mkl_cbwr_get_auto_branch_out_enums) |
| 338 | + |
| 339 | + return __mkl_cbwr_get_auto_branch_out_enums[status] |
216 | 340 |
|
217 | 341 |
|
218 | 342 | '''
|
|
0 commit comments