Skip to content

Commit 9aae656

Browse files
initial version of patching
1 parent 4542598 commit 9aae656

File tree

4 files changed

+185
-0
lines changed

4 files changed

+185
-0
lines changed

mkl_umath/setup.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,23 @@ def generate_umath_c(ext, build_dir):
156156
]
157157
)
158158

159+
from Cython.Build import cythonize
160+
from setuptools import Extension
161+
cythonize(Extension('_patch', sources=[join(wdir, 'patch.pyx'),]))
162+
163+
config.add_extension(
164+
name = '_patch',
165+
sources = [
166+
join(wdir, 'patch.c'),
167+
],
168+
libraries = mkl_libraries + ['loops_intel'],
169+
library_dirs = mkl_library_dirs,
170+
extra_compile_args = [
171+
# '-DNDEBUG',
172+
'-ggdb', '-O0', '-Wall', '-Wextra', '-DDEBUG',
173+
]
174+
)
175+
159176
config.add_data_dir('tests')
160177

161178
# if have_cython:

mkl_umath/src/loops_intel.c.src

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,3 +2468,41 @@ NPY_NO_EXPORT void
24682468
** END LOOPS **
24692469
*****************************************************************************
24702470
*/
2471+
2472+
PyUFuncGenericFunction get_func_by_name(char* s){
2473+
if (0) {
2474+
/**begin repeat
2475+
* #kind = sqrt, invsqrt, exp, exp2, expm1, erf, log, log2, log10, log1p, cos,
2476+
* sin, tan, arccos, arcsin, arctan, cosh, sinh, tanh, arccosh, arcsinh,
2477+
* arctanh, fabs, floor, ceil, rint, trunc, cbrt, add, subtract, multiply,
2478+
* divide, equal, not_equal, less, less_equal, greater, greater_equal,
2479+
* logical_and, logical_or, logical_xor, logical_not, isnan, isinf, isfinite,
2480+
* signbit, spacing, copysign, nextafter, maximum, minimum, fmax, fmin,
2481+
* floor_divide, remainder, divmod, square, reciprocal, conjugate, absolute,
2482+
* negative, positive, sign, modf, frexp, ldexp, ldexp_long, true_divide#
2483+
*/
2484+
/**begin repeat1
2485+
* #TYPE = FLOAT, DOUBLE#
2486+
*/
2487+
} else if (strcmp(s, "@TYPE@_@kind@") == 0) {
2488+
return @TYPE@_@kind@;
2489+
/**end repeat1**/
2490+
/**end repeat**/
2491+
2492+
/**begin repeat
2493+
* #kind = add, subtract, multiply, divide, floor_divide, greater,
2494+
* greater_equal, less, less_equal, equal, not_equal, logical_and, logical_or,
2495+
* logical_xor, logical_not, isnan, isinf, isfinite, square, reciprocal,
2496+
* conjugate, absolute, sign, maximum, minimum, fmax, fmin, true_divide#
2497+
*/
2498+
/**begin repeat1
2499+
* #TYPE = CFLOAT, CDOUBLE#
2500+
*/
2501+
} else if (strcmp(s, "@TYPE@_@kind@") == 0) {
2502+
return @TYPE@_@kind@;
2503+
/**end repeat1**/
2504+
/**end repeat**/
2505+
} else {
2506+
printf("Error! Unknown function in get_func_by_name: %s\n", s);
2507+
}
2508+
}

mkl_umath/src/loops_intel.h.src

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include "numpy/ndarraytypes.h"
55

6+
#include <string.h>
7+
68
/**begin repeat
79
* Float types
810
* #TYPE = FLOAT, DOUBLE#
@@ -274,4 +276,6 @@ NPY_NO_EXPORT void
274276
#undef CEQ
275277
#undef CNE
276278

279+
PyUFuncGenericFunction get_func_by_name(char* s);
280+
277281
#endif

mkl_umath/src/patch.pyx

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#*******************************************************************************
2+
# Copyright 2014-2020 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#******************************************************************************/
16+
17+
# distutils: language = c
18+
# cython: language_level=2
19+
20+
import mkl_umath._ufuncs as mu
21+
import numpy.core.umath as nu
22+
23+
cimport numpy as cnp
24+
import numpy as np
25+
26+
from libc.stdlib cimport malloc, free
27+
28+
cdef extern from "loops_intel.h":
29+
cnp.PyUFuncGenericFunction get_func_by_name(char*)
30+
31+
cnp.import_umath()
32+
33+
def _get_func_name(name, type):
34+
if type.startswith('f'):
35+
type_str = 'FLOAT'
36+
elif type.startswith('d'):
37+
type_str = 'DOUBLE'
38+
elif type.startswith('F'):
39+
type_str = 'CFLOAT'
40+
elif type.startswith('D'):
41+
type_str = 'CDOUBLE'
42+
else:
43+
raise ValueError("_get_func_name: Unexpected type specified!")
44+
func_name = type_str + '_' + name
45+
if type.startswith('fl') or type.startswith('dl'):
46+
func_name = func_name + '_long'
47+
return func_name
48+
49+
50+
cdef void _fill_signature(signature_str, int* signature):
51+
for i in range(len(signature_str)):
52+
if signature_str[i] == 'f':
53+
signature[i] = cnp.NPY_FLOAT
54+
elif signature_str[i] == 'd':
55+
signature[i] = cnp.NPY_DOUBLE
56+
elif signature_str[i] == 'F':
57+
signature[i] = cnp.NPY_CFLOAT
58+
elif signature_str[i] == 'D':
59+
signature[i] = cnp.NPY_CDOUBLE
60+
elif signature_str[i] == 'i':
61+
signature[i] = cnp.NPY_INT
62+
elif signature_str[i] == 'l':
63+
signature[i] = cnp.NPY_LONG
64+
elif signature_str[i] == '?':
65+
signature[i] = cnp.NPY_BOOL
66+
else:
67+
raise ValueError("_fill_signature: Unexpected type specified!")
68+
69+
70+
cdef cnp.PyUFuncGenericFunction fooSaved
71+
72+
funcs_dict = {}
73+
cdef cnp.PyUFuncGenericFunction* originalFuncs
74+
75+
76+
cdef c_do_patch():
77+
cdef int res
78+
79+
global originalFuncs
80+
81+
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
82+
func_number = 0
83+
for umath in umaths:
84+
mkl_umath = getattr(mu, umath)
85+
types = mkl_umath.types
86+
for type in types:
87+
funcs_dict[(umath, type)] = func_number
88+
func_number = func_number + 1
89+
originalFuncs = <cnp.PyUFuncGenericFunction *> malloc(len(funcs_dict) * sizeof(cnp.PyUFuncGenericFunction))
90+
91+
for func in funcs_dict:
92+
umath = func[0]
93+
type = func[1]
94+
np_umath = getattr(nu, umath)
95+
signature_str = type.replace('->', '')
96+
signature = <int *> malloc(len(signature_str) * sizeof(int))
97+
_fill_signature(signature_str, signature)
98+
ufunc_name = _get_func_name(umath, type)
99+
ufunc = get_func_by_name(str.encode(ufunc_name))
100+
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, ufunc, signature, &(originalFuncs[funcs_dict[func]]))
101+
free(signature)
102+
103+
104+
cdef c_do_unpatch():
105+
cdef int res
106+
cdef cnp.PyUFuncGenericFunction temp
107+
108+
global originalFuncs
109+
110+
for func in funcs_dict:
111+
umath = func[0]
112+
type = func[1]
113+
np_umath = getattr(nu, umath)
114+
signature_str = type.replace('->', '')
115+
signature = <int *> malloc(len(signature_str) * sizeof(int))
116+
_fill_signature(signature_str, signature)
117+
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, originalFuncs[funcs_dict[(umath, type)]], signature, &temp)
118+
free(signature)
119+
free(originalFuncs)
120+
121+
122+
def do_patch():
123+
c_do_patch()
124+
125+
def do_unpatch():
126+
c_do_unpatch()

0 commit comments

Comments
 (0)