@@ -25,29 +25,59 @@ import numpy as np
25
25
26
26
from libc.stdlib cimport malloc, free
27
27
28
- cnp.import_umath()
28
+ cimport cpython.pycapsule
29
29
30
- funcs_dict = {}
30
+ cnp.import_umath()
31
31
32
32
ctypedef struct function_info:
33
33
cnp.PyUFuncGenericFunction np_function
34
34
cnp.PyUFuncGenericFunction mkl_function
35
35
int * signature
36
36
37
- cdef function_info* functions
37
+ ctypedef struct functions_struct:
38
+ int count
39
+ function_info* functions
40
+
41
+
42
+ cdef const char * capsule_name = " functions_cache"
43
+
44
+
45
# PyCapsule destructor for the per-thread function cache.
#
# Frees everything reachable from the functions_struct stored in `caps`:
# each entry's `signature` array, then the `functions` array, then the
# struct itself (the reverse of the order in which _initialize_tls
# allocates them).
cdef void _capsule_destructor(object caps):
    cdef functions_struct* fs
    cdef int i

    if caps is None:
        print(" Nothing to destroy")
        return

    # Validate before extracting the pointer, as _get_functions does:
    # PyCapsule_GetPointer fails on a name mismatch or a non-capsule object,
    # and a destructor has no caller that could handle the resulting error.
    if not cpython.pycapsule.PyCapsule_IsValid(caps, capsule_name):
        return

    fs = <functions_struct *> cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
    if fs == NULL:
        return

    # `functions` comes from an unchecked malloc, so it may be NULL;
    # never index into it in that case.
    if fs.functions != NULL:
        for i in range(fs.count):
            free(fs.functions[i].signature)
        free(fs.functions)
    free(fs)
56
+
57
+
58
+ from threading import local as threading_local
59
+ _tls = threading_local()
60
+
61
+
62
def _is_tls_initialized():
    """Return True if the calling thread's function cache has been built.

    The flag lives on the thread-local ``_tls`` object and is set to
    ``True`` at the end of ``_initialize_tls``.  A thread that has never
    run the initializer has no such attribute, which counts as "not
    initialized".
    """
    # The attribute name must match the one assigned in _initialize_tls
    # (`_tls.initialized`) exactly; a stray leading space in the lookup key
    # would make this check always fail and rebuild the cache on every call.
    return getattr(_tls, "initialized", False) is True
38
64
39
- def fill_functions ():
40
- global functions
65
+
66
+ def _initialize_tls ():
67
+ cdef functions_struct* fs
68
+ cdef int funcs_count
69
+
70
+ _tls.functions_dict = {}
41
71
42
72
umaths = [i for i in dir (mu) if isinstance (getattr (mu, i), np.ufunc)]
43
73
funcs_count = 0
44
74
for umath in umaths:
45
75
mkl_umath = getattr (mu, umath)
46
- types = mkl_umath.types
47
- for type in types:
48
- funcs_count = funcs_count + 1
76
+ funcs_count = funcs_count + mkl_umath.ntypes
49
77
50
- functions = < function_info * > malloc(funcs_count * sizeof(function_info))
78
+ fs = < functions_struct * > malloc(sizeof(functions_struct))
79
+ fs[0 ].count = funcs_count
80
+ fs[0 ].functions = < function_info * > malloc(funcs_count * sizeof(function_info))
51
81
52
82
func_number = 0
53
83
for umath in umaths:
@@ -57,28 +87,51 @@ def fill_functions():
57
87
c_np_umath = < cnp.ufunc> np_umath
58
88
for type in mkl_umath.types:
59
89
np_index = np_umath.types.index(type )
60
- functions[func_number].np_function = c_np_umath.functions[np_index]
90
+ fs[ 0 ]. functions[func_number].np_function = c_np_umath.functions[np_index]
61
91
mkl_index = mkl_umath.types.index(type )
62
- functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
92
+ fs[ 0 ]. functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
63
93
64
94
nargs = c_mkl_umath.nargs
65
- functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
95
+ fs[ 0 ]. functions[func_number].signature = < int * > malloc(nargs * sizeof(int ))
66
96
for i in range (nargs):
67
- functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
97
+ fs[ 0 ]. functions[func_number].signature[i] = c_mkl_umath.types[mkl_index* nargs + i]
68
98
69
- funcs_dict [(umath, type )] = func_number
99
+ _tls.functions_dict [(umath, type )] = func_number
70
100
func_number = func_number + 1
71
101
102
+ _tls.functions_capsule = cpython.pycapsule.PyCapsule_New(< void * > fs, capsule_name, & _capsule_destructor)
103
+
104
+ _tls.initialized = True
105
+
106
+
107
def _get_func_dict():
    """Return the calling thread's function-index dictionary.

    The dictionary maps ``(ufunc name, type signature)`` to the entry's
    index in the C function table.  It is stored on the thread-local
    ``_tls`` object and is built on first use in each thread.
    """
    if _is_tls_initialized():
        return _tls.functions_dict

    # First call in this thread: build the cache, then hand it out.
    _initialize_tls()
    return _tls.functions_dict
72
111
73
- fill_functions()
74
112
75
- cdef c_do_patch():
113
# Return the calling thread's C array of function_info entries, building the
# thread-local cache first if this thread has not initialized it yet.
#
# Raises ValueError if the object stored in `_tls.functions_capsule` is not a
# valid capsule named `capsule_name`.
#
# `except? NULL` lets that ValueError reach the caller.  Without an exception
# clause, a cdef function returning a C pointer cannot reliably signal a
# Python exception, and the caller could go on to index an invalid pointer.
# The `?` form is used because NULL may also be a plain return value when no
# exception is set (the `functions` array comes from an unchecked malloc).
cdef function_info* _get_functions() except? NULL:
    cdef functions_struct* fs

    if not _is_tls_initialized():
        _initialize_tls()

    capsule = _tls.functions_capsule
    if not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name):
        raise ValueError(" Internal Error: invalid capsule stored in TLS")
    fs = <functions_struct *> cpython.pycapsule.PyCapsule_GetPointer(capsule, capsule_name)
    return fs.functions
125
+
126
+
127
+ cdef void c_do_patch():
76
128
cdef int res
77
129
cdef cnp.PyUFuncGenericFunction temp
78
130
cdef cnp.PyUFuncGenericFunction function
79
131
cdef int * signature
80
132
81
- global functions
133
+ funcs_dict = _get_func_dict()
134
+ functions = _get_functions()
82
135
83
136
for func in funcs_dict:
84
137
np_umath = getattr (nu, func[0 ])
@@ -88,13 +141,14 @@ cdef c_do_patch():
88
141
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, & temp)
89
142
90
143
91
- cdef c_do_unpatch():
144
+ cdef void c_do_unpatch():
92
145
cdef int res
93
146
cdef cnp.PyUFuncGenericFunction temp
94
147
cdef cnp.PyUFuncGenericFunction function
95
148
cdef int * signature
96
149
97
- global functions
150
+ funcs_dict = _get_func_dict()
151
+ functions = _get_functions()
98
152
99
153
for func in funcs_dict:
100
154
np_umath = getattr (nu, func[0 ])
@@ -107,5 +161,6 @@ cdef c_do_unpatch():
107
161
def do_patch():
    """Python-callable entry point that runs ``c_do_patch()``.

    The visible part of ``c_do_patch`` walks the thread-local function
    dictionary and calls ``PyUFunc_ReplaceLoopBySignature`` on the matching
    NumPy ufunc for each entry.  Presumably this installs the MKL loops;
    the full body is not visible here, so confirm against ``c_do_patch``.
    """
    c_do_patch()
109
163
164
+
110
165
def do_unpatch():
    """Python-callable entry point that runs ``c_do_unpatch()``.

    The visible part of ``c_do_unpatch`` walks the thread-local function
    dictionary and looks up the matching NumPy ufunc for each entry.
    Presumably it restores the original NumPy loops; the full body is not
    visible here, so confirm against ``c_do_unpatch``.
    """
    c_do_unpatch()
0 commit comments