@@ -34,18 +34,52 @@ except ImportError:
34
34
from numpy .core ._multiarray_tests import internal_overlap
35
35
36
36
from libc .string cimport memcpy
37
+ cimport cpython .pycapsule
38
+ from cpython .exc cimport (PyErr_Occurred , PyErr_Clear )
39
+ from cpython .mem cimport (PyMem_Malloc , PyMem_Free )
37
40
38
41
from threading import Lock
42
+ from threading import local as threading_local
39
43
_lock = Lock ()
40
44
45
+ # thread-local storage
46
+ _tls = threading_local ()
47
+
48
+ cdef const char * capsule_name = "dfti_cache"
49
+
50
cdef void _capsule_destructor(object caps):
    # Destructor handed to PyCapsule_New for the "dfti_cache" capsule.
    # It releases the DFTI resources referenced by the DftiCache struct via
    # _free_dfti_cache() and then returns the struct's memory with
    # PyMem_Free().
    cdef DftiCache *cache_ptr = NULL
    cdef int free_status = 0

    # Guard clause: there is no capsule to unpack.
    if caps is None:
        print("Nothing to destroy")
        return

    cache_ptr = <DftiCache *> cpython.pycapsule.PyCapsule_GetPointer(
        caps, capsule_name)
    free_status = _free_dfti_cache(cache_ptr)
    # The struct itself is released regardless of the status reported above.
    PyMem_Free(cache_ptr)

    # NOTE(review): this function is declared `cdef void` with no exception
    # clause; how the error below surfaces to the user depends on the Cython
    # version in use -- confirm.
    if free_status:
        raise ValueError("Internal Error: Freeing DFTI Cache returned with error = {}".format(free_status))
61
+
62
+
63
def _tls_dfti_cache_capsule():
    """Return the calling thread's "dfti_cache" capsule, creating it on first use.

    On the first call in a thread a DftiCache struct is allocated with
    PyMem_Malloc, its fields are reset, and it is wrapped in a PyCapsule
    (with _capsule_destructor as destructor) stored on the thread-local
    object `_tls`. Later calls in the same thread return the stored capsule.

    Returns
    -------
    capsule : PyCapsule
        Capsule named "dfti_cache" wrapping a `DftiCache *`.

    Raises
    ------
    MemoryError
        If the DftiCache struct cannot be allocated.
    ValueError
        If the object stored in thread-local storage is not a valid
        "dfti_cache" capsule.
    """
    cdef DftiCache *_cache_struct = NULL

    init = getattr(_tls, 'initialized', None)
    if init is None:
        _cache_struct = <DftiCache *> PyMem_Malloc(sizeof(DftiCache))
        # PyMem_Malloc reports failure by returning NULL; do not dereference it.
        if _cache_struct == NULL:
            raise MemoryError("Internal Error: could not allocate DFTI cache")
        # Important: PyMem_Malloc does not zero the memory, so both fields
        # must be initialized explicitly before the struct is used.
        _cache_struct.initialized = 0
        _cache_struct.hand = NULL
        try:
            new_capsule = cpython.pycapsule.PyCapsule_New(
                <void *> _cache_struct, capsule_name, &_capsule_destructor)
        except BaseException:
            # No capsule took ownership of the struct, so release it here.
            PyMem_Free(_cache_struct)
            raise
        # From here on the capsule owns the struct; its destructor frees it.
        _tls.capsule = new_capsule
        # Mark the thread as initialized only once the capsule is stored, so
        # a failed attempt can be retried on the next call.
        _tls.initialized = True

    capsule = getattr(_tls, 'capsule', None)
    if not cpython.pycapsule.PyCapsule_IsValid(capsule, capsule_name):
        raise ValueError("Internal Error: invalid capsule stored in TLS")
    return capsule
78
+
79
+
41
80
cdef extern from "Python.h" :
42
81
ctypedef int size_t
43
82
44
- void * PyMem_Malloc (size_t n )
45
- void PyMem_Free (void * buf )
46
-
47
- int PyErr_Occurred ()
48
- void PyErr_Clear ()
49
83
long PyInt_AsLong (object ob )
50
84
int PyObject_HasAttrString (object , char * )
51
85
@@ -58,32 +92,36 @@ cdef extern from *:
58
92
object PyArray_BASE (cnp .ndarray )
59
93
60
94
cdef extern from "src/mklfft.h" :
61
- int cdouble_mkl_fft1d_in (cnp .ndarray , int , int )
62
- int cfloat_mkl_fft1d_in (cnp .ndarray , int , int )
63
- int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
64
- int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
65
- int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
66
- int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray )
67
-
68
- int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int )
69
- int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int )
70
- int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
71
- int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
72
- int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int )
73
- int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray )
74
-
75
- int double_mkl_rfft_in (cnp .ndarray , int , int )
76
- int double_mkl_irfft_in (cnp .ndarray , int , int )
77
- int float_mkl_rfft_in (cnp .ndarray , int , int )
78
- int float_mkl_irfft_in (cnp .ndarray , int , int )
79
-
80
- int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
81
- int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
82
- int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray )
83
- int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
84
-
85
- int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
86
- int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray )
95
+ cdef struct DftiCache :
96
+ void * hand
97
+ int initialized
98
+ int _free_dfti_cache (DftiCache * )
99
+ int cdouble_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
100
+ int cfloat_mkl_fft1d_in (cnp .ndarray , int , int , DftiCache * )
101
+ int float_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
102
+ int cfloat_cfloat_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
103
+ int double_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
104
+ int cdouble_cdouble_mkl_fft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
105
+
106
+ int cdouble_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
107
+ int cfloat_mkl_ifft1d_in (cnp .ndarray , int , int , DftiCache * )
108
+ int float_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
109
+ int cfloat_cfloat_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
110
+ int double_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , int , DftiCache * )
111
+ int cdouble_cdouble_mkl_ifft1d_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
112
+
113
+ int double_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
114
+ int double_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
115
+ int float_mkl_rfft_in (cnp .ndarray , int , int , DftiCache * )
116
+ int float_mkl_irfft_in (cnp .ndarray , int , int , DftiCache * )
117
+
118
+ int double_double_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
119
+ int double_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
120
+ int float_float_mkl_rfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
121
+ int float_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
122
+
123
+ int cdouble_double_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
124
+ int cfloat_float_mkl_irfft_out (cnp .ndarray , int , int , cnp .ndarray , DftiCache * )
87
125
88
126
int cdouble_cdouble_mkl_fftnd_in (cnp .ndarray )
89
127
int cdouble_cdouble_mkl_ifftnd_in (cnp .ndarray )
@@ -268,6 +306,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
268
306
cdef int ALL_HARMONICS = 1
269
307
cdef char * c_error_msg = NULL
270
308
cdef bytes py_error_msg
309
+ cdef DftiCache * _cache
271
310
272
311
x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
273
312
& axis_ , & n_ , & in_place , & xnd , & dir_ , 0 )
@@ -296,16 +335,18 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
296
335
297
336
if in_place :
298
337
with _lock :
338
+ _cache_capsule = _tls_dfti_cache_capsule ()
339
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
299
340
if x_type is cnp .NPY_CDOUBLE :
300
341
if dir_ < 0 :
301
- status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
342
+ status = cdouble_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
302
343
else :
303
- status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
344
+ status = cdouble_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
304
345
elif x_type is cnp .NPY_CFLOAT :
305
346
if dir_ < 0 :
306
- status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ )
347
+ status = cfloat_mkl_ifft1d_in (x_arr , n_ , < int > axis_ , _cache )
307
348
else :
308
- status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ )
349
+ status = cfloat_mkl_fft1d_in (x_arr , n_ , < int > axis_ , _cache )
309
350
else :
310
351
status = 1
311
352
@@ -328,36 +369,38 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
328
369
329
370
# call out-of-place FFT
330
371
with _lock :
372
+ _cache_capsule = _tls_dfti_cache_capsule ()
373
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
331
374
if f_type is cnp .NPY_CDOUBLE :
332
375
if x_type is cnp .NPY_DOUBLE :
333
376
if dir_ < 0 :
334
377
status = double_cdouble_mkl_ifft1d_out (
335
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
378
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
336
379
else :
337
380
status = double_cdouble_mkl_fft1d_out (
338
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
381
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
339
382
elif x_type is cnp .NPY_CDOUBLE :
340
383
if dir_ < 0 :
341
384
status = cdouble_cdouble_mkl_ifft1d_out (
342
- x_arr , n_ , < int > axis_ , f_arr )
385
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
343
386
else :
344
387
status = cdouble_cdouble_mkl_fft1d_out (
345
- x_arr , n_ , < int > axis_ , f_arr )
388
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
346
389
else :
347
390
if x_type is cnp .NPY_FLOAT :
348
391
if dir_ < 0 :
349
392
status = float_cfloat_mkl_ifft1d_out (
350
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
393
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
351
394
else :
352
395
status = float_cfloat_mkl_fft1d_out (
353
- x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS )
396
+ x_arr , n_ , < int > axis_ , f_arr , ALL_HARMONICS , _cache )
354
397
elif x_type is cnp .NPY_CFLOAT :
355
398
if dir_ < 0 :
356
399
status = cfloat_cfloat_mkl_ifft1d_out (
357
- x_arr , n_ , < int > axis_ , f_arr )
400
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
358
401
else :
359
402
status = cfloat_cfloat_mkl_fft1d_out (
360
- x_arr , n_ , < int > axis_ , f_arr )
403
+ x_arr , n_ , < int > axis_ , f_arr , _cache )
361
404
362
405
if (status ):
363
406
c_error_msg = mkl_dfti_error (status )
@@ -388,6 +431,7 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
388
431
cdef int x_type , status
389
432
cdef char * c_error_msg = NULL
390
433
cdef bytes py_error_msg
434
+ cdef DftiCache * _cache
391
435
392
436
x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
393
437
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -414,16 +458,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
414
458
415
459
if in_place :
416
460
with _lock :
461
+ _cache_capsule = _tls_dfti_cache_capsule ()
462
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
417
463
if x_type is cnp .NPY_DOUBLE :
418
464
if dir_ < 0 :
419
- status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ )
465
+ status = double_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
420
466
else :
421
- status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ )
467
+ status = double_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
422
468
elif x_type is cnp .NPY_FLOAT :
423
469
if dir_ < 0 :
424
- status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ )
470
+ status = float_mkl_irfft_in (x_arr , n_ , < int > axis_ , _cache )
425
471
else :
426
- status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ )
472
+ status = float_mkl_rfft_in (x_arr , n_ , < int > axis_ , _cache )
427
473
else :
428
474
status = 1
429
475
@@ -444,16 +490,18 @@ def _rrfft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1):
444
490
445
491
# call out-of-place FFT
446
492
with _lock :
493
+ _cache_capsule = _tls_dfti_cache_capsule ()
494
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
447
495
if x_type is cnp .NPY_DOUBLE :
448
496
if dir_ < 0 :
449
- status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
497
+ status = double_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
450
498
else :
451
- status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
499
+ status = double_double_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
452
500
else :
453
501
if dir_ < 0 :
454
- status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
502
+ status = float_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
455
503
else :
456
- status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr )
504
+ status = float_float_mkl_rfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
457
505
458
506
if (status ):
459
507
c_error_msg = mkl_dfti_error (status )
@@ -479,6 +527,7 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
479
527
cdef int direction = 1 # dummy, only used for the sake of arg-processing
480
528
cdef char * c_error_msg = NULL
481
529
cdef bytes py_error_msg
530
+ cdef DftiCache * _cache
482
531
483
532
x_arr = __process_arguments (x , n , axis , overwrite_arg , direction ,
484
533
& axis_ , & n_ , & in_place , & xnd , & dir_ , 1 )
@@ -510,10 +559,14 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
510
559
# call out-of-place FFT
511
560
if x_type is cnp .NPY_FLOAT :
512
561
with _lock :
513
- status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
562
+ _cache_capsule = _tls_dfti_cache_capsule ()
563
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
564
+ status = float_cfloat_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
514
565
else :
515
566
with _lock :
516
- status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS )
567
+ _cache_capsule = _tls_dfti_cache_capsule ()
568
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
569
+ status = double_cdouble_mkl_fft1d_out (x_arr , n_ , < int > axis_ , f_arr , HALF_HARMONICS , _cache )
517
570
518
571
if (status ):
519
572
c_error_msg = mkl_dfti_error (status )
@@ -553,6 +606,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
553
606
cdef int direction = 1 # dummy, only used for the sake of arg-processing
554
607
cdef char * c_error_msg = NULL
555
608
cdef bytes py_error_msg
609
+ cdef DftiCache * _cache
556
610
557
611
int_n = _is_integral (n )
558
612
# nn gives the number elements along axis of the input that we use
@@ -592,10 +646,14 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False):
592
646
# call out-of-place FFT
593
647
if x_type is cnp .NPY_CFLOAT :
594
648
with _lock :
595
- status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
649
+ _cache_capsule = _tls_dfti_cache_capsule ()
650
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
651
+ status = cfloat_float_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
596
652
else :
597
653
with _lock :
598
- status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr )
654
+ _cache_capsule = _tls_dfti_cache_capsule ()
655
+ _cache = < DftiCache * > cpython .pycapsule .PyCapsule_GetPointer (_cache_capsule , capsule_name )
656
+ status = cdouble_double_mkl_irfft_out (x_arr , n_ , < int > axis_ , f_arr , _cache )
599
657
600
658
if (status ):
601
659
c_error_msg = mkl_dfti_error (status )
0 commit comments