Skip to content

Commit d7475f7

Browse files
committed
Allow the user to choose the FFT method
1 parent 0ba5ff4 commit d7475f7

File tree

1 file changed

+61
-20
lines changed

1 file changed

+61
-20
lines changed

pysteps/noise/fftgenerators.py

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,8 @@
3131

3232
import numpy as np
3333
from scipy import optimize
34-
# Use the pyfftw interface if it is installed. If not, fall back to the fftpack
35-
# interface provided by SciPy, and finally to numpy if SciPy is not installed.
36-
try:
37-
import pyfftw.interfaces.numpy_fft as fft
38-
import pyfftw
39-
# TODO: Caching and multithreading currently disabled because they give a
40-
# segfault with dask.
41-
#pyfftw.interfaces.cache.enable()
42-
fft_kwargs = {"threads":1, "planner_effort":"FFTW_ESTIMATE"}
43-
except ImportError:
44-
import scipy.fftpack as fft
45-
fft_kwargs = {}
46-
except ImportError:
47-
import numpy.fft as fft
48-
fft_kwargs = {}
4934
from .. import utils
35+
from .. utils import fft as fft_module
5036

5137
def initialize_param_2d_fft_filter(X, **kwargs):
5238
"""Takes one ore more 2d input fields, fits two spectral slopes, beta1 and beta2,
@@ -78,6 +64,9 @@ def initialize_param_2d_fft_filter(X, **kwargs):
7864
doplot : bool
7965
Plot the fit.
8066
Default : False
67+
fft_method : tuple
68+
A string or a (function,kwargs) tuple defining the FFT method to use
69+
(see utils.fft.get_method). Defaults to "numpy".
8170
8271
Returns
8372
-------
@@ -97,6 +86,11 @@ def initialize_param_2d_fft_filter(X, **kwargs):
9786
weighted = kwargs.get('weighted', True)
9887
rm_rdisc = kwargs.get('rm_disc', True)
9988
doplot = kwargs.get('doplot', False)
89+
fft = kwargs.get("fft_method", "numpy")
90+
if type(fft) == str:
91+
fft,fft_kwargs = fft_module.get_method(fft)
92+
else:
93+
fft,fft_kwargs = fft
10094

10195
X = X.copy()
10296

@@ -200,6 +194,9 @@ def initialize_nonparam_2d_fft_filter(X, **kwargs):
200194
rm_rdisc : bool
201195
Whether or not to remove the rain/no-rain disconituity. It assumes no-rain
202196
pixels are assigned with lowest value.
197+
fft_method : tuple
198+
A string or a (function,kwargs) tuple defining the FFT method to use
199+
(see utils.fft.get_method). Defaults to "numpy".
203200
204201
Returns
205202
-------
@@ -217,6 +214,11 @@ def initialize_nonparam_2d_fft_filter(X, **kwargs):
217214
donorm = kwargs.get('donorm', False)
218215
rm_rdisc = kwargs.get('rm_rdisc', True)
219216
use_full_fft = kwargs.get('use_full_fft', False)
217+
fft = kwargs.get("fft_method", "numpy")
218+
if type(fft) == str:
219+
fft,fft_kwargs = fft_module.get_method(fft)
220+
else:
221+
fft,fft_kwargs = fft
220222

221223
X = X.copy()
222224

@@ -257,7 +259,7 @@ def initialize_nonparam_2d_fft_filter(X, **kwargs):
257259

258260
return {"F":np.abs(F), "input_shape":X.shape[1:], "use_full_fft":use_full_fft}
259261

260-
def generate_noise_2d_fft_filter(F, randstate=np.random, seed=None):
262+
def generate_noise_2d_fft_filter(F, randstate=np.random, seed=None, fft_method=None):
261263
"""Produces a field of correlated noise using global Fourier filtering.
262264
263265
Parameters
@@ -270,6 +272,9 @@ def generate_noise_2d_fft_filter(F, randstate=np.random, seed=None):
270272
Optional random generator to use. If set to None, use numpy.random.
271273
seed : int
272274
Value to set a seed for the generator. None will not set the seed.
275+
fft_method : tuple
276+
A string or a (function,kwargs) tuple defining the FFT method to use
277+
(see utils.fft.get_method). Defaults to "numpy".
273278
274279
Returns
275280
-------
@@ -289,6 +294,14 @@ def generate_noise_2d_fft_filter(F, randstate=np.random, seed=None):
289294
if seed is not None:
290295
randstate.seed(seed)
291296

297+
if fft_method is None:
298+
fft,fft_kwargs = fft_module.get_method("numpy")
299+
else:
300+
if type(fft_method) == str:
301+
fft,fft_kwargs = fft_module.get_method(fft_method)
302+
else:
303+
fft,fft_kwargs = fft
304+
292305
# produce fields of white noise
293306
N = randstate.randn(input_shape[0], input_shape[1])
294307

@@ -335,6 +348,9 @@ def initialize_nonparam_2d_ssft_filter(X, **kwargs):
335348
rm_rdisc : bool
336349
Whether or not to remove the rain/no-rain disconituity. It assumes no-rain
337350
pixels are assigned with lowest value.
351+
fft_method : tuple
352+
A string or a (function,kwargs) tuple defining the FFT method to use
353+
(see utils.fft.get_method). Defaults to "numpy".
338354
339355
Returns
340356
-------
@@ -361,6 +377,11 @@ def initialize_nonparam_2d_ssft_filter(X, **kwargs):
361377
overlap = kwargs.get('overlap', 0.3)
362378
war_thr = kwargs.get('war_thr', 0.1)
363379
rm_rdisc = kwargs.get('rm_disc', True)
380+
fft = kwargs.get("fft_method", "numpy")
381+
if type(fft) == str:
382+
fft,fft_kwargs = fft_module.get_method(fft)
383+
else:
384+
fft,fft_kwargs = fft
364385

365386
X = X.copy()
366387

@@ -391,7 +412,8 @@ def initialize_nonparam_2d_ssft_filter(X, **kwargs):
391412

392413
# domain fourier filter
393414
F0 = initialize_nonparam_2d_fft_filter(X, win_type=win_type, donorm=True,
394-
use_full_fft=True)["F"]
415+
use_full_fft=True,
416+
fft_method=(fft,fft_kwargs))["F"]
395417
# and allocate it to the final grid
396418
F = np.zeros((num_windows_y, num_windows_x, F0.shape[0], F0.shape[1]))
397419
F += F0[np.newaxis, np.newaxis, :, :]
@@ -415,7 +437,8 @@ def initialize_nonparam_2d_ssft_filter(X, **kwargs):
415437
if war > war_thr:
416438
# the new filter
417439
F[i, j, : ,:] = initialize_nonparam_2d_fft_filter(X*mask[None, :, :],
418-
win_type=None, donorm=True, use_full_fft=True)["F"]
440+
win_type=None, donorm=True, use_full_fft=True,
441+
fft_method=(fft,fft_kwargs))["F"]
419442

420443
return {"F":F, "input_shape":X.shape[1:], "use_full_fft":True}
421444

@@ -446,6 +469,9 @@ def initialize_nonparam_2d_nested_filter(X, gridres=1.0, **kwargs):
446469
rm_rdisc : bool
447470
Whether or not to remove the rain/no-rain disconituity. It assumes no-rain
448471
pixels are assigned with lowest value.
472+
fft_method : tuple
473+
A string or a (function,kwargs) tuple defining the FFT method to use
474+
(see utils.fft.get_method). Defaults to "numpy".
449475
450476
Returns
451477
-------
@@ -464,6 +490,11 @@ def initialize_nonparam_2d_nested_filter(X, gridres=1.0, **kwargs):
464490
win_type = kwargs.get('win_type', 'flat-hanning')
465491
war_thr = kwargs.get('war_thr', 0.1)
466492
rm_rdisc = kwargs.get('rm_disc', True)
493+
fft = kwargs.get("fft_method", "numpy")
494+
if type(fft) == str:
495+
fft,fft_kwargs = fft_module.get_method(fft)
496+
else:
497+
fft,fft_kwargs = fft
467498

468499
X = X.copy()
469500

@@ -498,7 +529,8 @@ def initialize_nonparam_2d_nested_filter(X, gridres=1.0, **kwargs):
498529

499530
# domain fourier filter
500531
F0 = initialize_nonparam_2d_fft_filter(X, win_type=win_type, donorm=True,
501-
use_full_fft=True)["F"]
532+
use_full_fft=True,
533+
fft_method=(fft,fft_kwargs))["F"]
502534
# and allocate it to the final grid
503535
F = np.zeros((2**max_level, 2**max_level, F0.shape[0], F0.shape[1]))
504536
F += F0[np.newaxis, np.newaxis, :, :]
@@ -522,7 +554,8 @@ def initialize_nonparam_2d_nested_filter(X, gridres=1.0, **kwargs):
522554
if war > war_thr:
523555
# the new filter
524556
newfilter = initialize_nonparam_2d_fft_filter(X*mask[None, :, :],
525-
win_type=None, donorm=True, use_full_fft=True)["F"]
557+
win_type=None, donorm=True, use_full_fft=True,
558+
fft_method=(fft,fft_kwargs))["F"]
526559

527560
# compute logistic function to define weights as function of frequency
528561
# k controls the shape of the weighting function
@@ -566,6 +599,9 @@ def generate_noise_2d_ssft_filter(F, randstate=np.random, seed=None, **kwargs):
566599
win_type : string ['hanning', 'flat-hanning']
567600
Type of window used for localization.
568601
Default : flat-hanning
602+
fft_method : tuple
603+
A string or a (function,kwargs) tuple defining the FFT method to use
604+
(see utils.fft.get_method). Defaults to "numpy".
569605
570606
Returns
571607
-------
@@ -585,6 +621,11 @@ def generate_noise_2d_ssft_filter(F, randstate=np.random, seed=None, **kwargs):
585621
# defaults
586622
overlap = kwargs.get('overlap', 0.2)
587623
win_type = kwargs.get('win_type', 'flat-hanning')
624+
fft = kwargs.get("fft_method", "numpy")
625+
if type(fft) == str:
626+
fft,fft_kwargs = fft_module.get_method(fft)
627+
else:
628+
fft,fft_kwargs = fft
588629

589630
# set the seed
590631
if seed is not None:

0 commit comments

Comments
 (0)