Skip to content

Commit dbf7ea1

Browse files
committed
Add FFT methods to utils.interface
1 parent e285774 commit dbf7ea1

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed

pysteps/utils/fft.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,28 @@
66
# TODO: Caching and multithreading are currently disabled because they give
77
# a segfault with dask.
88
pyfftw.interfaces.cache.disable()
9-
pyfftw_kwargs = {"threads":1, "planner_effort":"FFTW_ESTIMATE"}
109
pyfftw_imported = True
1110
except ImportError:
1211
pyfftw_imported = False
13-
try:
14-
import scipy.fftpack as scipy_fft
15-
scipy_fft_kwargs = {}
16-
except ImportError:
17-
scipy_imported = False
18-
try:
19-
import numpy.fft as numpy_fft
20-
numpy_fft_kwargs = {}
21-
except ImportError:
22-
numpy_imported = False
12+
import scipy.fftpack as scipy_fft
13+
import numpy.fft as numpy_fft
14+
15+
from pysteps.exceptions import MissingOptionalDependency
16+
17+
# use numpy implementation of rfft2/irfft2 because they have not been
18+
# implemented in scipy.fftpack
19+
scipy_fft.rfft2 = numpy_fft.rfft2
20+
scipy_fft.irfft2 = numpy_fft.irfft2
2321

2422
def get_method(name):
23+
"""Return a callable function for the FFT method corresponding to the given name."""
2524
if name == "numpy":
26-
return numpy_fft
25+
return numpy_fft,{}
2726
elif name == "scipy":
28-
return scipy_fft
27+
return scipy_fft,{}
2928
elif name == "pyfftw":
30-
return pyfftw_fft
29+
if not pyfftw_imported:
30+
raise MissingOptionalDependency("pyfftw is required but it is not installed")
31+
return pyfftw_fft,{"threads":1, "planner_effort":"FFTW_ESTIMATE"}
3132
else:
3233
raise ValueError("unknown method %s, the available methods are 'numpy', 'scipy' and 'pyfftw'" % name)

pysteps/utils/interface.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from . import conversion
33
from . import transformation
44
from . import dimension
5+
from . import fft
56

67
def get_method(name):
78
"""Return a callable function for the utility method corresponding to the
8-
given name.\n\
9+
given name. For the FFT methods, the return value is a two-element tuple
10+
containing the function and a dictionary of keyword arguments.\n\
911
1012
Conversion methods:
1113
@@ -50,6 +52,18 @@ def get_method(name):
5052
| upscale | upscale the field |
5153
+-------------------+--------------------------------------------------------+
5254
55+
FFT methods (wrappers to different implementations):
56+
57+
+-------------------+--------------------------------------------------------+
58+
| Name | Description |
59+
+===================+========================================================+
60+
| numpy_fft | numpy.fft |
61+
+-------------------+--------------------------------------------------------+
62+
| scipy_fft | scipy.fftpack |
63+
+-------------------+--------------------------------------------------------+
64+
| pyfftw_fft | pyfftw.interfaces.numpy_fft |
65+
+-------------------+--------------------------------------------------------+
66+
5367
"""
5468

5569
if name is None:
@@ -82,6 +96,10 @@ def donothing(R, metadata, *args, **kwargs):
8296
methods_objects["clip"] = dimension.clip_domain
8397
methods_objects["square"] = dimension.square_domain
8498
methods_objects["upscale"] = dimension.aggregate_fields_space
99+
# FFT methods
100+
methods_objects["numpy_fft"] = fft.get_method("numpy")
101+
methods_objects["scipy_fft"] = fft.get_method("scipy")
102+
methods_objects["pyfftw_fft"] = fft.get_method("pyfftw")
85103

86104
try:
87105
return methods_objects[name]

0 commit comments

Comments
 (0)