Skip to content

Commit d24d714

Browse files
committed
Use the FFT interface module
1 parent e4ebb33 commit d24d714

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

pysteps/cascade/decomposition.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,7 @@
2323
"""
2424

2525
import numpy as np
26-
# Use the pyfftw interface if it is installed. If not, fall back to the fftpack
27-
# interface provided by SciPy, and finally to numpy if SciPy is not installed.
28-
try:
29-
import pyfftw.interfaces.numpy_fft as fft
30-
import pyfftw
31-
# TODO: Caching and multithreading currently disabled because they give a
32-
# segfault with dask.
33-
#pyfftw.interfaces.cache.enable()
34-
fft_kwargs = {"threads":1, "planner_effort":"FFTW_ESTIMATE"}
35-
except ImportError:
36-
import scipy.fftpack as fft
37-
fft_kwargs = {}
38-
except ImportError:
39-
import numpy.fft as fft
40-
fft_kwargs = {}
26+
from .. utils import fft as fft_module
4127

4228
def decomposition_fft(X, filter, **kwargs):
4329
"""Decompose a 2d input field into multiple spatial scales by using the Fast
@@ -53,17 +39,21 @@ def decomposition_fft(X, filter, **kwargs):
5339
5440
Other Parameters
5541
----------------
42+
fft_method : tuple
43+
A tuple defining the FFT method to use (see utils.fft.get_method).
44+
Defaults to numpy.fft.
5645
MASK : array_like
57-
Optional mask to use for computing the statistics for the cascade levels.
58-
Pixels with MASK==False are excluded from the computations.
46+
Optional mask to use for computing the statistics for the cascade levels.
47+
Pixels with MASK==False are excluded from the computations.
5948
6049
Returns
6150
-------
6251
out : ndarray
63-
A dictionary described in the module documentation. The number of cascade
64-
levels is determined from the filter (see bandpass_filters.py).
52+
A dictionary described in the module documentation. The number of cascade
53+
levels is determined from the filter (see bandpass_filters.py).
6554
6655
"""
56+
fft,fft_kwargs = kwargs.get("fft", fft_module.get_method("numpy"))
6757
MASK = kwargs.get("MASK", None)
6858

6959
if len(X.shape) != 2:

0 commit comments

Comments
 (0)