Skip to content

Commit 727a6fd

Browse files
Added auto-loading of interfaces
In scipy fft return NotImplemented for complex256 and double128 inputs.
1 parent 0ed11bf commit 727a6fd

File tree

6 files changed

+114
-25
lines changed

6 files changed

+114
-25
lines changed

mkl_fft/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
rfft_numpy, irfft_numpy, rfftn_numpy, irfftn_numpy)
2929

3030
from ._version import __version__
31+
import mkl_fft.interfaces
3132

3233
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn', 'rfft', 'irfft',
33-
'rfft_numpy', 'irfft_numpy', 'rfftn_numpy', 'irfftn_numpy']
34+
'rfft_numpy', 'irfft_numpy', 'rfftn_numpy', 'irfftn_numpy', 'interfaces']

mkl_fft/_scipy_fft_backend.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ def __exit__(self, *args):
162162

163163

164164
def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
165-
x = _float_utils.__upcast_float16_array(a)
165+
try:
166+
x = _float_utils.__upcast_float16_array(a)
167+
except ValueError:
168+
return NotImplemented
166169
with Workers(workers):
167170
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
168171
if _unitary(norm):
@@ -171,7 +174,10 @@ def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
171174

172175

173176
def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
174-
x = _float_utils.__upcast_float16_array(a)
177+
try:
178+
x = _float_utils.__upcast_float16_array(a)
179+
except ValueError:
180+
return NotImplemented
175181
with Workers(workers):
176182
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
177183
if _unitary(norm):
@@ -180,7 +186,10 @@ def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
180186

181187

182188
def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
183-
x = _float_utils.__upcast_float16_array(a)
189+
try:
190+
x = _float_utils.__upcast_float16_array(a)
191+
except ValueError:
192+
return NotImplemented
184193
with Workers(workers):
185194
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
186195
if _unitary(norm):
@@ -192,7 +201,10 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
192201

193202

194203
def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
195-
x = _float_utils.__upcast_float16_array(a)
204+
try:
205+
x = _float_utils.__upcast_float16_array(a)
206+
except ValueError:
207+
return NotImplemented
196208
with Workers(workers):
197209
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
198210
if _unitary(norm):
@@ -205,7 +217,10 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
205217

206218

207219
def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
208-
x = _float_utils.__upcast_float16_array(a)
220+
try:
221+
x = _float_utils.__upcast_float16_array(a)
222+
except ValueError:
223+
return NotImplemented
209224
with Workers(workers):
210225
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
211226
if _unitary(norm):
@@ -218,7 +233,10 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
218233

219234

220235
def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
221-
x = _float_utils.__upcast_float16_array(a)
236+
try:
237+
x = _float_utils.__upcast_float16_array(a)
238+
except ValueError:
239+
return NotImplemented
222240
with Workers(workers):
223241
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
224242
if _unitary(norm):
@@ -231,7 +249,10 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
231249

232250

233251
def rfft(a, n=None, axis=-1, norm=None, workers=None):
234-
x = _float_utils.__upcast_float16_array(a)
252+
try:
253+
x = _float_utils.__upcast_float16_array(a)
254+
except ValueError:
255+
return NotImplemented
235256
unitary = _unitary(norm)
236257
x = _float_utils.__downcast_float128_array(x)
237258
if unitary and n is None:
@@ -245,8 +266,10 @@ def rfft(a, n=None, axis=-1, norm=None, workers=None):
245266

246267

247268
def irfft(a, n=None, axis=-1, norm=None, workers=None):
248-
x = _float_utils.__upcast_float16_array(a)
249-
x = _float_utils.__downcast_float128_array(x)
269+
try:
270+
x = _float_utils.__upcast_float16_array(a)
271+
except ValueError:
272+
return NotImplemented
250273
with Workers(workers):
251274
output = _pydfti.irfft_numpy(x, n=n, axis=axis)
252275
if _unitary(norm):
@@ -255,21 +278,27 @@ def irfft(a, n=None, axis=-1, norm=None, workers=None):
255278

256279

257280
def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
258-
x = _float_utils.__upcast_float16_array(a)
259-
x = _float_utils.__downcast_float128_array(a)
281+
try:
282+
x = _float_utils.__upcast_float16_array(a)
283+
except ValueError:
284+
return NotImplemented
260285
return rfftn(x, s, axes, norm, workers)
261286

262287

263288
def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
264-
x = _float_utils.__upcast_float16_array(a)
265-
x = _float_utils.__downcast_float128_array(x)
289+
try:
290+
x = _float_utils.__upcast_float16_array(a)
291+
except ValueError:
292+
return NotImplemented
266293
return irfftn(x, s, axes, norm, workers)
267294

268295

269296
def rfftn(a, s=None, axes=None, norm=None, workers=None):
270297
unitary = _unitary(norm)
271-
x = _float_utils.__upcast_float16_array(a)
272-
x = _float_utils.__downcast_float128_array(x)
298+
try:
299+
x = _float_utils.__upcast_float16_array(a)
300+
except ValueError:
301+
return NotImplemented
273302
if unitary:
274303
x = asarray(x)
275304
s, axes = _cook_nd_args(x, s, axes)
@@ -282,8 +311,10 @@ def rfftn(a, s=None, axes=None, norm=None, workers=None):
282311

283312

284313
def irfftn(a, s=None, axes=None, norm=None, workers=None):
285-
x = _float_utils.__upcast_float16_array(a)
286-
x = _float_utils.__downcast_float128_array(x)
314+
try:
315+
x = _float_utils.__upcast_float16_array(a)
316+
except ValueError:
317+
return NotImplemented
287318
with Workers(workers):
288319
output = _pydfti.irfftn_numpy(x, s, axes)
289320
if _unitary(norm):

mkl_fft/interfaces/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,27 @@
1+
# Copyright (c) 2017-2023, Intel Corporation
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# * Redistributions of source code must retain the above copyright notice,
7+
# this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of Intel Corporation nor the names of its contributors
12+
# may be used to endorse or promote products derived from this software
13+
# without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
19+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25+
126
from .. import _numpy_fft as numpy_fft
227
from .. import _scipy_fft_backend as scipy_fft

mkl_fft/tests/test_fft1d.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from __future__ import division, absolute_import, print_function
28-
2927
import numpy as np
3028
from numpy.testing import (
3129
TestCase, run_module_suite, assert_, assert_raises, assert_equal,
@@ -127,7 +125,7 @@ def test_vector5(self):
127125
f1 = mkl_fft.fft(x, overwrite_x=True)
128126
f2 = mkl_fft.fft(self.xz1[::-2])
129127
assert_(np.allclose(f1,f2))
130-
128+
131129
def test_vector6(self):
132130
"fft in place"
133131
x = self.xz1.copy()
@@ -137,8 +135,8 @@ def test_vector6(self):
137135
x = self.xz1.copy()
138136
f1 = mkl_fft.fft(x[::-2], overwrite_x=True)
139137
assert_( not np.allclose(x, self.xz1) ) # this is also in-place
140-
assert_( np.allclose(x[-2::-2], self.xz1[-2::-2]) )
141-
assert_( np.allclose(x[-1::-2], f1) )
138+
assert_( np.allclose(x[-2::-2], self.xz1[-2::-2]) )
139+
assert_( np.allclose(x[-1::-2], f1) )
142140

143141
def test_vector7(self):
144142
"fft of real array is the same as fft of its complex cast"

mkl_fft/tests/test_fftnd.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from __future__ import division, absolute_import, print_function
28-
2927
import numpy as np
3028
from numpy.testing import (
3129
TestCase, run_module_suite, assert_, assert_raises, assert_equal,

mkl_fft/tests/test_interfaces.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017-2023, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import mkl_fft.interfaces as mfi
28+
import pytest
29+
30+
31+
def test_interfaces_has_numpy():
32+
assert hasattr(mfi, 'numpy_fft')
33+
34+
35+
def test_interfaces_has_scipy():
36+
assert hasattr(mfi, 'scipy_fft')

0 commit comments

Comments
 (0)