Skip to content

Commit 6a468e8

Browse files
committed
Fix MissingOptionalDependency that was always thrown when pyfftw is not installed and also remove redundant utils.fft.get_method
1 parent bbb44ca commit 6a468e8

File tree

4 files changed

+32
-46
lines changed

4 files changed

+32
-46
lines changed

pysteps/cascade/decomposition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"""
2424

2525
import numpy as np
26-
from .. utils import fft as fft_module
26+
from .. import utils
2727

2828
def decomposition_fft(X, filter, **kwargs):
2929
"""Decompose a 2d input field into multiple spatial scales by using the Fast
@@ -55,7 +55,7 @@ def decomposition_fft(X, filter, **kwargs):
5555
"""
5656
fft = kwargs.get("fft_method", "numpy")
5757
if type(fft) == str:
58-
fft,fft_kwargs = fft_module.get_method(fft)
58+
fft,fft_kwargs = utils.get_method(fft)
5959
else:
6060
fft,fft_kwargs = fft
6161
MASK = kwargs.get("MASK", None)

pysteps/noise/fftgenerators.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import numpy as np
3333
from scipy import optimize
3434
from .. import utils
35-
from .. utils import fft as fft_module
35+
from .. import utils
3636

3737
def initialize_param_2d_fft_filter(X, **kwargs):
3838
"""Takes one ore more 2d input fields, fits two spectral slopes, beta1 and beta2,
@@ -88,7 +88,7 @@ def initialize_param_2d_fft_filter(X, **kwargs):
8888
doplot = kwargs.get('doplot', False)
8989
fft = kwargs.get("fft_method", "numpy")
9090
if type(fft) == str:
91-
fft,fft_kwargs = fft_module.get_method(fft)
91+
fft,fft_kwargs = utils.get_method(fft)
9292
else:
9393
fft,fft_kwargs = fft
9494

@@ -216,7 +216,7 @@ def initialize_nonparam_2d_fft_filter(X, **kwargs):
216216
use_full_fft = kwargs.get('use_full_fft', False)
217217
fft = kwargs.get("fft_method", "numpy")
218218
if type(fft) == str:
219-
fft,fft_kwargs = fft_module.get_method(fft)
219+
fft,fft_kwargs = utils.get_method(fft)
220220
else:
221221
fft,fft_kwargs = fft
222222

@@ -295,10 +295,10 @@ def generate_noise_2d_fft_filter(F, randstate=np.random, seed=None, fft_method=N
295295
randstate.seed(seed)
296296

297297
if fft_method is None:
298-
fft,fft_kwargs = fft_module.get_method("numpy")
298+
fft,fft_kwargs = utils.get_method("numpy")
299299
else:
300300
if type(fft_method) == str:
301-
fft,fft_kwargs = fft_module.get_method(fft_method)
301+
fft,fft_kwargs = utils.get_method(fft_method)
302302
else:
303303
fft,fft_kwargs = fft
304304

@@ -379,7 +379,7 @@ def initialize_nonparam_2d_ssft_filter(X, **kwargs):
379379
rm_rdisc = kwargs.get('rm_disc', True)
380380
fft = kwargs.get("fft_method", "numpy")
381381
if type(fft) == str:
382-
fft,fft_kwargs = fft_module.get_method(fft)
382+
fft,fft_kwargs = utils.get_method(fft)
383383
else:
384384
fft,fft_kwargs = fft
385385

@@ -492,7 +492,7 @@ def initialize_nonparam_2d_nested_filter(X, gridres=1.0, **kwargs):
492492
rm_rdisc = kwargs.get('rm_disc', True)
493493
fft = kwargs.get("fft_method", "numpy")
494494
if type(fft) == str:
495-
fft,fft_kwargs = fft_module.get_method(fft)
495+
fft,fft_kwargs = utils.get_method(fft)
496496
else:
497497
fft,fft_kwargs = fft
498498

@@ -623,7 +623,7 @@ def generate_noise_2d_ssft_filter(F, randstate=np.random, seed=None, **kwargs):
623623
win_type = kwargs.get('win_type', 'flat-hanning')
624624
fft = kwargs.get("fft_method", "numpy")
625625
if type(fft) == str:
626-
fft,fft_kwargs = fft_module.get_method(fft)
626+
fft,fft_kwargs = utils.get_method(fft)
627627
else:
628628
fft,fft_kwargs = fft
629629

pysteps/utils/fft.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,15 @@
33
try:
44
import pyfftw.interfaces.numpy_fft as pyfftw_fft
55
import pyfftw
6-
# TODO: Caching and multithreading are currently disabled because they give
7-
# a segfault with dask.
6+
# TODO: Caching is currently disabled because it gives segfault with dask.
87
pyfftw.interfaces.cache.disable()
98
pyfftw_imported = True
109
except ImportError:
1110
pyfftw_imported = False
1211
import scipy.fftpack as scipy_fft
1312
import numpy.fft as numpy_fft
1413

15-
from pysteps.exceptions import MissingOptionalDependency
16-
1714
# use numpy implementation of rfft2/irfft2 because they have not been
1815
# implemented in scipy.fftpack
1916
scipy_fft.rfft2 = numpy_fft.rfft2
2017
scipy_fft.irfft2 = numpy_fft.irfft2
21-
22-
def get_method(name):
23-
"""Return a callable function for the FFT method corresponding to the given
24-
name.
25-
26-
Parameters
27-
----------
28-
name : str
29-
The name of the method. The available options are 'numpy', 'scipy' and
30-
'pyfftw'
31-
32-
Returns
33-
-------
34-
out : tuple
35-
A two-element tuple containing the FFT module and a dictionary of
36-
default keyword arguments for calling the FFT method. Each module
37-
implements the numpy.fft interface.
38-
"""
39-
if name == "numpy":
40-
return numpy_fft,{}
41-
elif name == "scipy":
42-
return scipy_fft,{}
43-
elif name == "pyfftw":
44-
if not pyfftw_imported:
45-
raise MissingOptionalDependency("pyfftw is required but it is not installed")
46-
return pyfftw_fft,{"threads":1, "planner_effort":"FFTW_ESTIMATE"}
47-
else:
48-
raise ValueError("unknown method %s, the available methods are 'numpy', 'scipy' and 'pyfftw'" % name)

pysteps/utils/interface.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from . import transformation
44
from . import dimension
55
from . import fft
6+
from pysteps.exceptions import MissingOptionalDependency
67

78
def get_method(name):
89
"""Return a callable function for the utility method corresponding to the
@@ -97,13 +98,29 @@ def donothing(R, metadata, *args, **kwargs):
9798
methods_objects["square"] = dimension.square_domain
9899
methods_objects["upscale"] = dimension.aggregate_fields_space
99100
# FFT methods
100-
methods_objects["numpy_fft"] = fft.get_method("numpy")
101-
methods_objects["scipy_fft"] = fft.get_method("scipy")
102-
methods_objects["pyfftw_fft"] = fft.get_method("pyfftw")
101+
methods_objects["numpy"] = _get_fft_method("numpy")
102+
methods_objects["scipy"] = _get_fft_method("scipy")
103103

104104
try:
105-
return methods_objects[name]
105+
if name == "pyfftw":
106+
return _get_fft_method("pyfftw")
107+
else:
108+
return methods_objects[name]
106109

107110
except KeyError as e:
108111
raise ValueError("Unknown method %s\n" % e +
109112
"Supported methods:%s" % str(methods_objects.keys()))
113+
114+
def _get_fft_method(name):
115+
if name == "numpy":
116+
return fft.numpy_fft,{}
117+
elif name == "scipy":
118+
return fft.scipy_fft,{}
119+
elif name == "pyfftw":
120+
if not fft.pyfftw_imported:
121+
raise MissingOptionalDependency("pyfftw is required but it is not installed")
122+
# TODO: Multithreading is currently disabled because it gives segfault
123+
# with dask.
124+
return fft.pyfftw_fft,{"threads":1, "planner_effort":"FFTW_ESTIMATE"}
125+
else:
126+
raise ValueError("unknown method %s, the available methods are 'numpy', 'scipy' and 'pyfftw'" % name)

0 commit comments

Comments
 (0)