Skip to content

Commit d3b0c2c

Browse files
Do not use of global dftiCache
Replaced use of global variable to store cached DFTI descriptor in favor of using Thread Local Storage to store it.
1 parent b94dc80 commit d3b0c2c

File tree

3 files changed

+256
-203
lines changed

3 files changed

+256
-203
lines changed

mkl_fft/_pydfti.pyx

Lines changed: 113 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,52 @@ except ImportError:
3434
from numpy.core._multiarray_tests import internal_overlap
3535

3636
from libc.string cimport memcpy
37+
cimport cpython.pycapsule
38+
from cpython.exc cimport (PyErr_Occurred, PyErr_Clear)
39+
from cpython.mem cimport (PyMem_Malloc, PyMem_Free)
3740

3841
from threading import Lock
42+
from threading import local as threading_local
3943
_lock = Lock()
4044

45+
# thread-local storage
46+
_tls = threading_local()
47+
48+
cdef const char *capsule_name = "dfti_cache"
49+
50+
cdef void _capsule_destructor(object caps):
51+
cdef DftiCache *_cache = NULL
52+
cdef int status = 0
53+
if (caps is None):
54+
print("Nothing to destroy")
55+
return
56+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(caps, capsule_name)
57+
status = _free_dfti_cache(_cache)
58+
PyMem_Free(_cache)
59+
if (status != 0):
60+
raise ValueError("Internal Error: Freeing DFTI Cache returned with error = {}".format(status))
61+
62+
63+
def _tls_dfti_cache_capsule():
64+
cdef DftiCache *_cache_struct
65+
66+
init = getattr(_tls, 'initialized', None)
67+
if (init is None):
68+
_cache_struct = <DftiCache *> PyMem_Malloc(sizeof(DftiCache));
69+
# important to initialized
70+
_cache_struct.initialized = 0
71+
_cache_struct.hand = NULL
72+
_tls.initialized = True
73+
_tls.capsule = cpython.pycapsule.PyCapsule_New(<void *>_cache_struct, capsule_name, &_capsule_destructor)
74+
capsule = getattr(_tls, 'capsule', None)
75+
if (not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name)):
76+
raise ValueError("Internal Error: invalid capsule stored in TLS")
77+
return capsule
78+
79+
4180
cdef extern from "Python.h":
4281
ctypedef int size_t
4382

44-
void* PyMem_Malloc(size_t n)
45-
void PyMem_Free(void* buf)
46-
47-
int PyErr_Occurred()
48-
void PyErr_Clear()
4983
long PyInt_AsLong(object ob)
5084
int PyObject_HasAttrString(object, char*)
5185

@@ -58,32 +92,36 @@ cdef extern from *:
5892
object PyArray_BASE(cnp.ndarray)
5993

6094
cdef extern from "src/mklfft.h":
61-
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int)
62-
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int)
63-
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
64-
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray)
65-
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
66-
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray)
67-
68-
int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int)
69-
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int)
70-
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
71-
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray)
72-
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int)
73-
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray)
74-
75-
int double_mkl_rfft_in(cnp.ndarray, int, int)
76-
int double_mkl_irfft_in(cnp.ndarray, int, int)
77-
int float_mkl_rfft_in(cnp.ndarray, int, int)
78-
int float_mkl_irfft_in(cnp.ndarray, int, int)
79-
80-
int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray)
81-
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
82-
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray)
83-
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
84-
85-
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
86-
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray)
95+
cdef struct DftiCache:
96+
void * hand
97+
int initialized
98+
int _free_dfti_cache(DftiCache *)
99+
int cdouble_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
100+
int cfloat_mkl_fft1d_in(cnp.ndarray, int, int, DftiCache*)
101+
int float_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
102+
int cfloat_cfloat_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
103+
int double_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
104+
int cdouble_cdouble_mkl_fft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
105+
106+
int cdouble_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
107+
int cfloat_mkl_ifft1d_in(cnp.ndarray, int, int, DftiCache*)
108+
int float_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
109+
int cfloat_cfloat_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarra, DftiCache*)
110+
int double_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, int, DftiCache*)
111+
int cdouble_cdouble_mkl_ifft1d_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
112+
113+
int double_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
114+
int double_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
115+
int float_mkl_rfft_in(cnp.ndarray, int, int, DftiCache*)
116+
int float_mkl_irfft_in(cnp.ndarray, int, int, DftiCache*)
117+
118+
int double_double_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
119+
int double_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
120+
int float_float_mkl_rfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
121+
int float_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
122+
123+
int cdouble_double_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
124+
int cfloat_float_mkl_irfft_out(cnp.ndarray, int, int, cnp.ndarray, DftiCache*)
87125

88126
int cdouble_cdouble_mkl_fftnd_in(cnp.ndarray)
89127
int cdouble_cdouble_mkl_ifftnd_in(cnp.ndarray)
@@ -268,6 +306,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
268306
cdef int ALL_HARMONICS = 1
269307
cdef char * c_error_msg = NULL
270308
cdef bytes py_error_msg
309+
cdef DftiCache *_cache
271310

272311
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
273312
&axis_, &n_, &in_place, &xnd, &dir_, 0)
@@ -296,16 +335,18 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
296335

297336
if in_place:
298337
with _lock:
338+
_cache_capsule = _tls_dfti_cache_capsule()
339+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
299340
if x_type is cnp.NPY_CDOUBLE:
300341
if dir_ < 0:
301-
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_)
342+
status = cdouble_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
302343
else:
303-
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_)
344+
status = cdouble_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
304345
elif x_type is cnp.NPY_CFLOAT:
305346
if dir_ < 0:
306-
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_)
347+
status = cfloat_mkl_ifft1d_in(x_arr, n_, <int> axis_, _cache)
307348
else:
308-
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_)
349+
status = cfloat_mkl_fft1d_in(x_arr, n_, <int> axis_, _cache)
309350
else:
310351
status = 1
311352

@@ -328,36 +369,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
328369

329370
# call out-of-place FFT
330371
with _lock:
372+
_cache_capsule = _tls_dfti_cache_capsule()
373+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
331374
if f_type is cnp.NPY_CDOUBLE:
332375
if x_type is cnp.NPY_DOUBLE:
333376
if dir_ < 0:
334377
status = double_cdouble_mkl_ifft1d_out(
335-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
378+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
336379
else:
337380
status = double_cdouble_mkl_fft1d_out(
338-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
381+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
339382
elif x_type is cnp.NPY_CDOUBLE:
340383
if dir_ < 0:
341384
status = cdouble_cdouble_mkl_ifft1d_out(
342-
x_arr, n_, <int> axis_, f_arr)
385+
x_arr, n_, <int> axis_, f_arr, _cache)
343386
else:
344387
status = cdouble_cdouble_mkl_fft1d_out(
345-
x_arr, n_, <int> axis_, f_arr)
388+
x_arr, n_, <int> axis_, f_arr, _cache)
346389
else:
347390
if x_type is cnp.NPY_FLOAT:
348391
if dir_ < 0:
349392
status = float_cfloat_mkl_ifft1d_out(
350-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
393+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
351394
else:
352395
status = float_cfloat_mkl_fft1d_out(
353-
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS)
396+
x_arr, n_, <int> axis_, f_arr, ALL_HARMONICS, _cache)
354397
elif x_type is cnp.NPY_CFLOAT:
355398
if dir_ < 0:
356399
status = cfloat_cfloat_mkl_ifft1d_out(
357-
x_arr, n_, <int> axis_, f_arr)
400+
x_arr, n_, <int> axis_, f_arr, _cache)
358401
else:
359402
status = cfloat_cfloat_mkl_fft1d_out(
360-
x_arr, n_, <int> axis_, f_arr)
403+
x_arr, n_, <int> axis_, f_arr, _cache)
361404

362405
if (status):
363406
c_error_msg = mkl_dfti_error(status)
@@ -388,6 +431,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
388431
cdef int x_type, status
389432
cdef char * c_error_msg = NULL
390433
cdef bytes py_error_msg
434+
cdef DftiCache *_cache
391435

392436
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
393437
&axis_, &n_, &in_place, &xnd, &dir_, 1)
@@ -414,16 +458,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
414458

415459
if in_place:
416460
with _lock:
461+
_cache_capsule = _tls_dfti_cache_capsule()
462+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
417463
if x_type is cnp.NPY_DOUBLE:
418464
if dir_ < 0:
419-
status = double_mkl_irfft_in(x_arr, n_, <int> axis_)
465+
status = double_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
420466
else:
421-
status = double_mkl_rfft_in(x_arr, n_, <int> axis_)
467+
status = double_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
422468
elif x_type is cnp.NPY_FLOAT:
423469
if dir_ < 0:
424-
status = float_mkl_irfft_in(x_arr, n_, <int> axis_)
470+
status = float_mkl_irfft_in(x_arr, n_, <int> axis_, _cache)
425471
else:
426-
status = float_mkl_rfft_in(x_arr, n_, <int> axis_)
472+
status = float_mkl_rfft_in(x_arr, n_, <int> axis_, _cache)
427473
else:
428474
status = 1
429475

@@ -444,16 +490,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
444490

445491
# call out-of-place FFT
446492
with _lock:
493+
_cache_capsule = _tls_dfti_cache_capsule()
494+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
447495
if x_type is cnp.NPY_DOUBLE:
448496
if dir_ < 0:
449-
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
497+
status = double_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
450498
else:
451-
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
499+
status = double_double_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
452500
else:
453501
if dir_ < 0:
454-
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
502+
status = float_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
455503
else:
456-
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr)
504+
status = float_float_mkl_rfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
457505

458506
if (status):
459507
c_error_msg = mkl_dfti_error(status)
@@ -479,6 +527,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
479527
cdef int direction = 1 # dummy, only used for the sake of arg-processing
480528
cdef char * c_error_msg = NULL
481529
cdef bytes py_error_msg
530+
cdef DftiCache *_cache
482531

483532
x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
484533
&axis_, &n_, &in_place, &xnd, &dir_, 1)
@@ -510,10 +559,14 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
510559
# call out-of-place FFT
511560
if x_type is cnp.NPY_FLOAT:
512561
with _lock:
513-
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
562+
_cache_capsule = _tls_dfti_cache_capsule()
563+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
564+
status = float_cfloat_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
514565
else:
515566
with _lock:
516-
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS)
567+
_cache_capsule = _tls_dfti_cache_capsule()
568+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
569+
status = double_cdouble_mkl_fft1d_out(x_arr, n_, <int> axis_, f_arr, HALF_HARMONICS, _cache)
517570

518571
if (status):
519572
c_error_msg = mkl_dfti_error(status)
@@ -553,6 +606,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
553606
cdef int direction = 1 # dummy, only used for the sake of arg-processing
554607
cdef char * c_error_msg = NULL
555608
cdef bytes py_error_msg
609+
cdef DftiCache *_cache
556610

557611
int_n = _is_integral(n)
558612
# nn gives the number elements along axis of the input that we use
@@ -592,10 +646,14 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
592646
# call out-of-place FFT
593647
if x_type is cnp.NPY_CFLOAT:
594648
with _lock:
595-
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
649+
_cache_capsule = _tls_dfti_cache_capsule()
650+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
651+
status = cfloat_float_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
596652
else:
597653
with _lock:
598-
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr)
654+
_cache_capsule = _tls_dfti_cache_capsule()
655+
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(_cache_capsule, capsule_name)
656+
status = cdouble_double_mkl_irfft_out(x_arr, n_, <int> axis_, f_arr, _cache)
599657

600658
if (status):
601659
c_error_msg = mkl_dfti_error(status)

0 commit comments

Comments
 (0)