@@ -25,98 +25,83 @@ import numpy as np
25
25
26
26
from libc.stdlib cimport malloc, free
27
27
28
- cdef extern from " loops_intel.h" :
29
- cnp.PyUFuncGenericFunction get_func_by_name(char * )
30
-
31
28
cnp.import_umath()
32
29
33
- def _get_func_name (name , type ):
34
- if type .startswith(' f' ):
35
- type_str = ' FLOAT'
36
- elif type .startswith(' d' ):
37
- type_str = ' DOUBLE'
38
- elif type .startswith(' F' ):
39
- type_str = ' CFLOAT'
40
- elif type .startswith(' D' ):
41
- type_str = ' CDOUBLE'
42
- else :
43
- raise ValueError (" _get_func_name: Unexpected type specified!" )
44
- func_name = type_str + ' _' + name
45
- if type .startswith(' fl' ) or type .startswith(' dl' ):
46
- func_name = func_name + ' _long'
47
- return func_name
48
-
49
-
50
- cdef void _fill_signature(signature_str, int * signature):
51
- for i in range (len (signature_str)):
52
- if signature_str[i] == ' f' :
53
- signature[i] = cnp.NPY_FLOAT
54
- elif signature_str[i] == ' d' :
55
- signature[i] = cnp.NPY_DOUBLE
56
- elif signature_str[i] == ' F' :
57
- signature[i] = cnp.NPY_CFLOAT
58
- elif signature_str[i] == ' D' :
59
- signature[i] = cnp.NPY_CDOUBLE
60
- elif signature_str[i] == ' i' :
61
- signature[i] = cnp.NPY_INT
62
- elif signature_str[i] == ' l' :
63
- signature[i] = cnp.NPY_LONG
64
- elif signature_str[i] == ' ?' :
65
- signature[i] = cnp.NPY_BOOL
66
- else :
67
- raise ValueError (" _fill_signature: Unexpected type specified!" )
68
-
69
-
70
- cdef cnp.PyUFuncGenericFunction fooSaved
71
-
72
30
funcs_dict = {}
73
- cdef cnp.PyUFuncGenericFunction* originalFuncs
74
31
32
+ ctypedef struct function_info:
33
+ cnp.PyUFuncGenericFunction np_function
34
+ cnp.PyUFuncGenericFunction mkl_function
35
+ int * signature
75
36
76
- cdef c_do_patch():
77
- cdef int res
37
+ cdef function_info* functions
78
38
79
- global originalFuncs
39
+ def fill_functions ():
40
+ global functions
80
41
81
42
umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
82
- func_number = 0
43
+ funcs_count = 0
83
44
for umath in umaths:
84
45
mkl_umath = getattr (mu, umath)
85
46
types = mkl_umath.types
86
47
for type in types:
48
+ funcs_count = funcs_count + 1
49
+
50
+ functions = < function_info * > malloc(funcs_count * sizeof(function_info))
51
+
52
+ func_number = 0
53
+ for umath in umaths:
54
+ mkl_umath = getattr (mu, umath)
55
+ np_umath = getattr (nu, umath)
56
+ c_mkl_umath = < cnp.ufunc> mkl_umath
57
+ c_np_umath = < cnp.ufunc> np_umath
58
+ for type in mkl_umath.types:
59
+ np_index = np_umath.types.index(type )
60
+ functions[func_number].np_function = c_np_umath.functions[np_index]
61
+ mkl_index = mkl_umath.types.index(type )
62
+ functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
63
+
64
+ nargs = c_mkl_umath.nargs
65
+ functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
66
+ for i in range (nargs):
67
+ functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
68
+
87
69
funcs_dict[(umath, type )] = func_number
88
70
func_number = func_number + 1
89
- originalFuncs = < cnp.PyUFuncGenericFunction * > malloc(len (funcs_dict) * sizeof(cnp.PyUFuncGenericFunction))
71
+
72
+
73
+ fill_functions()
74
+
75
+ cdef c_do_patch():
76
+ cdef int res
77
+ cdef cnp.PyUFuncGenericFunction temp
78
+ cdef cnp.PyUFuncGenericFunction function
79
+ cdef int * signature
80
+
81
+ global functions
90
82
91
83
for func in funcs_dict:
92
- umath = func[0 ]
93
- type = func[1 ]
94
- np_umath = getattr (nu, umath)
95
- signature_str = type .replace(' ->' , ' ' )
96
- signature = < int * > malloc(len (signature_str) * sizeof(int ))
97
- _fill_signature(signature_str, signature)
98
- ufunc_name = _get_func_name(umath, type )
99
- ufunc = get_func_by_name(str .encode(ufunc_name))
100
- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, ufunc, signature, & (originalFuncs[funcs_dict[func]]))
101
- free(signature)
84
+ np_umath = getattr (nu, func[0 ])
85
+ index = funcs_dict[func]
86
+ function = functions[index].mkl_function
87
+ signature = functions[index].signature
88
+ res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
102
89
103
90
104
91
cdef c_do_unpatch():
105
92
cdef int res
106
93
cdef cnp.PyUFuncGenericFunction temp
94
+ cdef cnp.PyUFuncGenericFunction function
95
+ cdef int * signature
107
96
108
- global originalFuncs
97
+ global functions
109
98
110
99
for func in funcs_dict:
111
- umath = func[0 ]
112
- type = func[1 ]
113
- np_umath = getattr (nu, umath)
114
- signature_str = type .replace(' ->' , ' ' )
115
- signature = < int * > malloc(len (signature_str) * sizeof(int ))
116
- _fill_signature(signature_str, signature)
117
- res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, originalFuncs[funcs_dict[(umath, type )]], signature, & temp)
118
- free(signature)
119
- free(originalFuncs)
100
+ np_umath = getattr (nu, func[0 ])
101
+ index = funcs_dict[func]
102
+ function = functions[index].np_function
103
+ signature = functions[index].signature
104
+ res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
120
105
121
106
122
107
def do_patch ():
0 commit comments