Skip to content

Commit 351b255

Browse files
thread local storage with capsule
1 parent d2fca85 commit 351b255

File tree

1 file changed

+74
-19
lines changed

1 file changed

+74
-19
lines changed

mkl_umath/src/patch.pyx

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,59 @@ import numpy as np
2525

2626
from libc.stdlib cimport malloc, free
2727

28-
cnp.import_umath()
28+
cimport cpython.pycapsule
2929

30-
funcs_dict = {}
30+
cnp.import_umath()
3131

3232
ctypedef struct function_info:
3333
cnp.PyUFuncGenericFunction np_function
3434
cnp.PyUFuncGenericFunction mkl_function
3535
int* signature
3636

37-
cdef function_info* functions
37+
ctypedef struct functions_struct:
38+
int count
39+
function_info* functions
40+
41+
42+
cdef const char *capsule_name = "functions_cache"
43+
44+
45+
cdef void _capsule_destructor(object caps):
46+
cdef functions_struct* fs
47+
48+
if (caps is None):
49+
print("Nothing to destroy")
50+
return
51+
fs = <functions_struct *>cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
52+
for i in range(fs[0].count):
53+
free(fs[0].functions[i].signature)
54+
free(fs[0].functions)
55+
free(fs)
56+
57+
58+
from threading import local as threading_local
59+
_tls = threading_local()
60+
61+
62+
def _is_tls_initialized():
63+
return (getattr(_tls, 'initialized', None) is not None) and (_tls.initialized == True)
3864

39-
def fill_functions():
40-
global functions
65+
66+
def _initialize_tls():
67+
cdef functions_struct* fs
68+
cdef int funcs_count
69+
70+
_tls.functions_dict = {}
4171

4272
umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
4373
funcs_count = 0
4474
for umath in umaths:
4575
mkl_umath = getattr(mu, umath)
46-
types = mkl_umath.types
47-
for type in types:
48-
funcs_count = funcs_count + 1
76+
funcs_count = funcs_count + mkl_umath.ntypes
4977

50-
functions = <function_info *> malloc(funcs_count * sizeof(function_info))
78+
fs = <functions_struct *> malloc(sizeof(functions_struct))
79+
fs[0].count = funcs_count
80+
fs[0].functions = <function_info *> malloc(funcs_count * sizeof(function_info))
5181

5282
func_number = 0
5383
for umath in umaths:
@@ -57,28 +87,51 @@ def fill_functions():
5787
c_np_umath = <cnp.ufunc>np_umath
5888
for type in mkl_umath.types:
5989
np_index = np_umath.types.index(type)
60-
functions[func_number].np_function = c_np_umath.functions[np_index]
90+
fs[0].functions[func_number].np_function = c_np_umath.functions[np_index]
6191
mkl_index = mkl_umath.types.index(type)
62-
functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
92+
fs[0].functions[func_number].mkl_function = c_mkl_umath.functions[mkl_index]
6393

6494
nargs = c_mkl_umath.nargs
65-
functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
95+
fs[0].functions[func_number].signature = <int *> malloc(nargs * sizeof(int))
6696
for i in range(nargs):
67-
functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
97+
fs[0].functions[func_number].signature[i] = c_mkl_umath.types[mkl_index*nargs + i]
6898

69-
funcs_dict[(umath, type)] = func_number
99+
_tls.functions_dict[(umath, type)] = func_number
70100
func_number = func_number + 1
71101

102+
_tls.functions_capsule = cpython.pycapsule.PyCapsule_New(<void *>fs, capsule_name, &_capsule_destructor)
103+
104+
_tls.initialized = True
105+
106+
107+
def _get_func_dict():
108+
if not _is_tls_initialized():
109+
_initialize_tls()
110+
return _tls.functions_dict
72111

73-
fill_functions()
74112

75-
cdef c_do_patch():
113+
cdef function_info* _get_functions():
114+
cdef function_info* functions
115+
cdef functions_struct* fs
116+
117+
if not _is_tls_initialized():
118+
_initialize_tls()
119+
120+
capsule = _tls.functions_capsule
121+
if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
122+
raise ValueError("Internal Error: invalid capsule stored in TLS")
123+
fs = <functions_struct *>cpython.pycapsule.PyCapsule_GetPointer(capsule, capsule_name)
124+
return fs[0].functions
125+
126+
127+
cdef void c_do_patch():
76128
cdef int res
77129
cdef cnp.PyUFuncGenericFunction temp
78130
cdef cnp.PyUFuncGenericFunction function
79131
cdef int* signature
80132

81-
global functions
133+
funcs_dict = _get_func_dict()
134+
functions = _get_functions()
82135

83136
for func in funcs_dict:
84137
np_umath = getattr(nu, func[0])
@@ -88,13 +141,14 @@ cdef c_do_patch():
88141
res = cnp.PyUFunc_ReplaceLoopBySignature(np_umath, function, signature, &temp)
89142

90143

91-
cdef c_do_unpatch():
144+
cdef void c_do_unpatch():
92145
cdef int res
93146
cdef cnp.PyUFuncGenericFunction temp
94147
cdef cnp.PyUFuncGenericFunction function
95148
cdef int* signature
96149

97-
global functions
150+
funcs_dict = _get_func_dict()
151+
functions = _get_functions()
98152

99153
for func in funcs_dict:
100154
np_umath = getattr(nu, func[0])
@@ -107,5 +161,6 @@ cdef c_do_unpatch():
107161
def do_patch():
108162
c_do_patch()
109163

164+
110165
def do_unpatch():
111166
c_do_unpatch()

0 commit comments

Comments
 (0)