Skip to content

Commit d2fca85

Browse files
patching with native ufunc struct
1 parent 9aae656 commit d2fca85

File tree

3 files changed

+53
-108
lines changed

3 files changed

+53
-108
lines changed

mkl_umath/src/loops_intel.c.src

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,41 +2468,3 @@ NPY_NO_EXPORT void
24682468
** END LOOPS **
24692469
*****************************************************************************
24702470
*/
2471-
2472-
PyUFuncGenericFunction get_func_by_name(char* s){
2473-
if (0) {
2474-
/**begin repeat
2475-
* #kind = sqrt, invsqrt, exp, exp2, expm1, erf, log, log2, log10, log1p, cos,
2476-
* sin, tan, arccos, arcsin, arctan, cosh, sinh, tanh, arccosh, arcsinh,
2477-
* arctanh, fabs, floor, ceil, rint, trunc, cbrt, add, subtract, multiply,
2478-
* divide, equal, not_equal, less, less_equal, greater, greater_equal,
2479-
* logical_and, logical_or, logical_xor, logical_not, isnan, isinf, isfinite,
2480-
* signbit, spacing, copysign, nextafter, maximum, minimum, fmax, fmin,
2481-
* floor_divide, remainder, divmod, square, reciprocal, conjugate, absolute,
2482-
* negative, positive, sign, modf, frexp, ldexp, ldexp_long, true_divide#
2483-
*/
2484-
/**begin repeat1
2485-
* #TYPE = FLOAT, DOUBLE#
2486-
*/
2487-
} else if (strcmp(s, "@TYPE@_@kind@") == 0) {
2488-
return @TYPE@_@kind@;
2489-
/**end repeat1**/
2490-
/**end repeat**/
2491-
2492-
/**begin repeat
2493-
* #kind = add, subtract, multiply, divide, floor_divide, greater,
2494-
* greater_equal, less, less_equal, equal, not_equal, logical_and, logical_or,
2495-
* logical_xor, logical_not, isnan, isinf, isfinite, square, reciprocal,
2496-
* conjugate, absolute, sign, maximum, minimum, fmax, fmin, true_divide#
2497-
*/
2498-
/**begin repeat1
2499-
* #TYPE = CFLOAT, CDOUBLE#
2500-
*/
2501-
} else if (strcmp(s, "@TYPE@_@kind@") == 0) {
2502-
return @TYPE@_@kind@;
2503-
/**end repeat1**/
2504-
/**end repeat**/
2505-
} else {
2506-
printf("Error! Unknown function in get_func_by_name: %s\n", s);
2507-
}
2508-
}

mkl_umath/src/loops_intel.h.src

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,4 @@ NPY_NO_EXPORT void
276276
#undef CEQ
277277
#undef CNE
278278

279-
PyUFuncGenericFunction get_func_by_name(char* s);
280-
281279
#endif

mkl_umath/src/patch.pyx

Lines changed: 53 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -25,98 +25,83 @@ import numpy as np
2525

2626
from libc.stdlib cimport malloc, free
2727

28-
cdef extern from "loops_intel.h":
29-
cnp.PyUFuncGenericFunction get_func_by_name(char*)
30-
3128
cnp.import_umath()
3229

33-
def _get_func_name(name, type):
34-
if type.startswith('f'):
35-
type_str = 'FLOAT'
36-
elif type.startswith('d'):
37-
type_str = 'DOUBLE'
38-
elif type.startswith('F'):
39-
type_str = 'CFLOAT'
40-
elif type.startswith('D'):
41-
type_str = 'CDOUBLE'
42-
else:
43-
raise ValueError("_get_func_name: Unexpected type specified!")
44-
func_name = type_str + '_' + name
45-
if type.startswith('fl') or type.startswith('dl'):
46-
func_name = func_name + '_long'
47-
return func_name
48-
49-
50-
cdef void _fill_signature(signature_str, int* signature):
51-
for i in range(len(signature_str)):
52-
if signature_str[i] == 'f':
53-
signature[i] = cnp.NPY_FLOAT
54-
elif signature_str[i] == 'd':
55-
signature[i] = cnp.NPY_DOUBLE
56-
elif signature_str[i] == 'F':
57-
signature[i] = cnp.NPY_CFLOAT
58-
elif signature_str[i] == 'D':
59-
signature[i] = cnp.NPY_CDOUBLE
60-
elif signature_str[i] == 'i':
61-
signature[i] = cnp.NPY_INT
62-
elif signature_str[i] == 'l':
63-
signature[i] = cnp.NPY_LONG
64-
elif signature_str[i] == '?':
65-
signature[i] = cnp.NPY_BOOL
66-
else:
67-
raise ValueError("_fill_signature: Unexpected type specified!")
68-
69-
70-
cdef cnp.PyUFuncGenericFunction fooSaved
71-
7230
funcs_dict = {}
73-
cdef cnp.PyUFuncGenericFunction* originalFuncs
7431

32+
ctypedef struct function_info:
33+
cnp.PyUFuncGenericFunction np_function
34+
cnp.PyUFuncGenericFunction mkl_function
35+
int* signature
7536

76-
cdef c_do_patch():
77-
cdef int res
37+
cdef function_info* functions
7838

79-
global originalFuncs
39+
def fill_functions():
40+
global functions
8041

8142
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
82-
func_number = 0
43+
funcs_count = 0
8344
for umath in umaths:
8445
mkl_umath = getattr(mu, umath)
8546
types = mkl_umath.types
8647
for type in types:
48+
funcs_count = funcs_count + 1
49+
50+
functions = <function_info *> malloc(funcs_count * sizeof(function_info))
51+
52+
func_number = 0
53+
for umath in umaths:
54+
mkl_umath = getattr(mu, umath)
55+
np_umath = getattr(nu, umath)
56+
c_mkl_umath = <cnp.ufunc>mkl_umath
57+
c_np_umath = <cnp.ufunc>np_umath
58+
for type in mkl_umath.types:
59+
np_index = np_umath.types.index(type)
60+
functions[func_number].np_function = c_np_umath.functions[np_index]
61+
mkl_index = mkl_umath.types.index(type)
62+
functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
63+
64+
nargs = c_mkl_umath.nargs
65+
functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
66+
for i in range(nargs):
67+
functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
68+
8769
funcs_dict[(umath, type)] = func_number
8870
func_number = func_number + 1
89-
originalFuncs = <cnp.PyUFuncGenericFunction *> malloc(len(funcs_dict) * sizeof(cnp.PyUFuncGenericFunction))
71+
72+
73+
fill_functions()
74+
75+
cdef c_do_patch():
76+
cdef int res
77+
cdef cnp.PyUFuncGenericFunction temp
78+
cdef cnp.PyUFuncGenericFunction function
79+
cdef int* signature
80+
81+
global functions
9082

9183
for func in funcs_dict:
92-
umath = func[0]
93-
type = func[1]
94-
np_umath = getattr(nu, umath)
95-
signature_str = type.replace('->', '')
96-
signature = <int *> malloc(len(signature_str) * sizeof(int))
97-
_fill_signature(signature_str, signature)
98-
ufunc_name = _get_func_name(umath, type)
99-
ufunc = get_func_by_name(str.encode(ufunc_name))
100-
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, ufunc, signature, &(originalFuncs[funcs_dict[func]]))
101-
free(signature)
84+
np_umath = getattr(nu, func[0])
85+
index = funcs_dict[func]
86+
function = functions[index].mkl_function
87+
signature = functions[index].signature
88+
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
10289

10390

10491
cdef c_do_unpatch():
10592
cdef int res
10693
cdef cnp.PyUFuncGenericFunction temp
94+
cdef cnp.PyUFuncGenericFunction function
95+
cdef int* signature
10796

108-
global originalFuncs
97+
global functions
10998

11099
for func in funcs_dict:
111-
umath = func[0]
112-
type = func[1]
113-
np_umath = getattr(nu, umath)
114-
signature_str = type.replace('->', '')
115-
signature = <int *> malloc(len(signature_str) * sizeof(int))
116-
_fill_signature(signature_str, signature)
117-
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, originalFuncs[funcs_dict[(umath, type)]], signature, &temp)
118-
free(signature)
119-
free(originalFuncs)
100+
np_umath = getattr(nu, func[0])
101+
index = funcs_dict[func]
102+
function = functions[index].np_function
103+
signature = functions[index].signature
104+
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
120105

121106

122107
def do_patch():

0 commit comments

Comments
 (0)