Skip to content

Commit bf7560e

Browse files
Add class DftiBackend suitable for scipy.fft.register_backend
1 parent efae8da commit bf7560e

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,28 @@ def get_max_threads_count(self):
6767

6868

6969
_hardware_counts = _cpu_max_threads_count()
70-
70+
7171

7272
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
7373
'rfft', 'irfft', 'rfft2', 'irfft2', 'rfftn', 'irfftn',
7474
'hfft', 'ihfft', 'hfft2', 'ihfft2', 'hfftn', 'ihfftn',
7575
'dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn',
7676
'fftshift', 'ifftshift', 'fftfreq', 'rfftfreq', 'get_workers',
77-
'set_workers', 'next_fast_len']
77+
'set_workers', 'next_fast_len', 'DftiBackend']
7878

79-
__ua_domain__ = 'numpy.scipy.fft'
80-
__implemented = dict()
8179

82-
def __ua_function__(method, args, kwargs):
83-
"""Fetch registered UA function."""
84-
fn = __implemented.get(method, None)
85-
if fn is None:
86-
return NotImplemented
87-
return fn(*args, **kwargs)
80+
class DftiBackend:
81+
__ua_domain__ = "numpy.scipy.fft"
82+
@staticmethod
83+
def __ua_function__(method, args, kwargs):
84+
"""Fetch registered UA function."""
85+
fn = __implemented.get(method, None)
86+
if fn is None:
87+
return NotImplemented
88+
return fn(*args, **kwargs)
89+
90+
91+
__implemented = dict()
8892

8993

9094
def _implements(scipy_func):

0 commit comments

Comments
 (0)