Skip to content

Commit 7c63669

Browse files
dmitrii-zagornyiGitHub Enterprise
authored andcommitted
Merge pull request #11 from SAT/feature/tc/mkl_changes
mkl_set_mpi and mkl_cbwr_set improvements
2 parents c9d9b89 + c1fcb98 commit 7c63669

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

mkl-service/_mkl_service.pyx

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

2626

27-
cimport _mkl_service as mkl
2827
import six
28+
cimport _mkl_service as mkl
2929

3030

3131
# Version Information
@@ -260,7 +260,7 @@ cpdef verbose(enable):
260260
return __verbose(enable)
261261

262262

263-
cpdef set_mpi(vendor, custom_library_name):
263+
cpdef set_mpi(vendor, custom_library_name=None):
264264
"""
265265
Sets the implementation of the message-passing interface to be used by Intel MKL.
266266
https://software.intel.com/en-us/mkl-developer-reference-c-mkl-set-mpi
@@ -617,6 +617,7 @@ cdef inline __cbwr_set(branch=None):
617617
"""
618618
__variables = {
619619
'input': {
620+
'off': mkl.MKL_CBWR_BRANCH_OFF,
620621
'auto': mkl.MKL_CBWR_AUTO,
621622
'compatible': mkl.MKL_CBWR_COMPATIBLE,
622623
'sse2': mkl.MKL_CBWR_SSE2,
@@ -655,6 +656,7 @@ cdef inline __cbwr_get(cnr_const=None):
655656
'all': mkl.MKL_CBWR_ALL,
656657
},
657658
'output': {
659+
mkl.MKL_CBWR_BRANCH_OFF: 'off',
658660
mkl.MKL_CBWR_AUTO: 'auto',
659661
mkl.MKL_CBWR_COMPATIBLE: 'compatible',
660662
mkl.MKL_CBWR_SSE2: 'sse2',
@@ -667,7 +669,6 @@ cdef inline __cbwr_get(cnr_const=None):
667669
mkl.MKL_CBWR_AVX512_MIC: 'avx512_mic',
668670
mkl.MKL_CBWR_AVX512: 'avx512',
669671
mkl.MKL_CBWR_SUCCESS: 'success',
670-
mkl.MKL_CBWR_BRANCH_OFF: 'branch_off',
671672
mkl.MKL_CBWR_ERR_INVALID_INPUT: 'err_invalid_input',
672673
},
673674
}
@@ -780,7 +781,7 @@ cdef inline __verbose(enable):
780781
return bool(mkl.mkl_verbose(enable))
781782

782783

783-
cdef inline __set_mpi(vendor, custom_library_name):
784+
cdef inline __set_mpi(vendor, custom_library_name=None):
784785
"""
785786
Sets the implementation of the message-passing interface to be used by Intel MKL.
786787
https://software.intel.com/en-us/mkl-developer-reference-c-mkl-set-mpi
@@ -799,10 +800,15 @@ cdef inline __set_mpi(vendor, custom_library_name):
799800
-3: 'MPI library cannot be set at this point',
800801
},
801802
}
803+
assert((vendor is not 'custom' and custom_library_name is None) or
804+
(vendor is 'custom' and custom_library_name is not None))
802805
mkl_vendor = __mkl_str_to_int(vendor, __variables['input'])
803806

804-
cdef bytes c_bytes = custom_library_name.encode()
805-
cdef char* c_string = c_bytes
807+
cdef bytes c_bytes
808+
cdef char* c_string = ''
809+
if custom_library_name is not None:
810+
c_bytes = custom_library_name.encode()
811+
c_string = c_bytes
806812
mkl_status = mkl.mkl_set_mpi(mkl_vendor, c_string)
807813

808814
status = __mkl_int_to_str(mkl_status, __variables['output'])

mkl-service/tests/test_mkl_service.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

2626

27+
from nose.tools import nottest
2728
import mkl
2829

2930

@@ -159,6 +160,9 @@ def test_set_memory_limit(self):
159160

160161
class test_conditional_numerical_reproducibility_control:
161162
# https://software.intel.com/en-us/mkl-developer-reference-c-conditional-numerical-reproducibility-control
163+
def test_cbwr_set_off(self):
164+
mkl.cbwr_set(branch='off')
165+
162166
def test_cbwr_set_auto(self):
163167
mkl.cbwr_set(branch='auto')
164168

@@ -234,17 +238,18 @@ def test_verbose_false(self):
234238
def test_verbose_true(self):
235239
mkl.verbose(True)
236240

237-
# def test_set_mpi_custom(self):
238-
# mkl.set_mpi('custom', 'test')
241+
def test_set_mpi_custom(self):
242+
mkl.set_mpi('custom', 'custom_library_name')
239243

240-
# def test_set_mpi_msmpi(self):
241-
# mkl.set_mpi('msmpi', 'test')
244+
@nottest
245+
def test_set_mpi_msmpi(self):
246+
mkl.set_mpi('msmpi')
242247

243-
# def test_set_mpi_intelmpi(self):
244-
# mkl.set_mpi('intelmpi', 'test')
248+
def test_set_mpi_intelmpi(self):
249+
mkl.set_mpi('intelmpi')
245250

246-
# def test_set_mpi_mpich2(self):
247-
# mkl.set_mpi('mpich2', 'test')
251+
def test_set_mpi_mpich2(self):
252+
mkl.set_mpi('mpich2')
248253

249254

250255
class test_vm_service_functions():

0 commit comments

Comments
 (0)