@@ -37,6 +37,7 @@ def __init__(
3737 ifftshift_before : bool = False ,
3838 fftshift_after : bool = False ,
3939 dtype : DTypeLike = "complex128" ,
40+ ** kwargs_fft ,
4041 ) -> None :
4142 super ().__init__ (
4243 dims = dims ,
@@ -54,6 +55,7 @@ def __init__(
5455 f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be casted to { self .cdtype } ."
5556 )
5657
58+ self ._kwargs_fft = kwargs_fft
5759 self ._norm_kwargs = {"norm" : None } # equivalent to "backward" in Numpy/Scipy
5860 if self .norm is _FFTNorms .ORTHO :
5961 self ._norm_kwargs ["norm" ] = "ortho"
@@ -74,14 +76,18 @@ def _matvec(self, x: NDArray) -> NDArray:
7476 if not self .clinear :
7577 x = ncp .real (x )
7678 if self .real :
77- y = ncp .fft .rfft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
79+ y = ncp .fft .rfft (
80+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
81+ )
7882 # Apply scaling to obtain a correct adjoint for this operator
7983 y = ncp .swapaxes (y , - 1 , self .axis )
8084 # y[..., 1 : 1 + (self.nfft - 1) // 2] *= ncp.sqrt(2)
8185 y = inplace_multiply (ncp .sqrt (2 ), y , self .slice )
8286 y = ncp .swapaxes (y , self .axis , - 1 )
8387 else :
84- y = ncp .fft .fft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
88+ y = ncp .fft .fft (
89+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
90+ )
8591 if self .norm is _FFTNorms .ONE_OVER_N :
8692 y *= self ._scale
8793 if self .fftshift_after :
@@ -101,9 +107,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
101107 # x[..., 1 : 1 + (self.nfft - 1) // 2] /= ncp.sqrt(2)
102108 x = inplace_divide (ncp .sqrt (2 ), x , self .slice )
103109 x = ncp .swapaxes (x , self .axis , - 1 )
104- y = ncp .fft .irfft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
110+ y = ncp .fft .irfft (
111+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
112+ )
105113 else :
106- y = ncp .fft .ifft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
114+ y = ncp .fft .ifft (
115+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
116+ )
107117 if self .norm is _FFTNorms .NONE :
108118 y *= self ._scale
109119
@@ -139,6 +149,7 @@ def __init__(
139149 ifftshift_before : bool = False ,
140150 fftshift_after : bool = False ,
141151 dtype : DTypeLike = "complex128" ,
152+ ** kwargs_fft ,
142153 ) -> None :
143154 super ().__init__ (
144155 dims = dims ,
@@ -152,6 +163,7 @@ def __init__(
152163 dtype = dtype ,
153164 )
154165
166+ self ._kwargs_fft = kwargs_fft
155167 self ._norm_kwargs = {"norm" : None } # equivalent to "backward" in Numpy/Scipy
156168 if self .norm is _FFTNorms .ORTHO :
157169 self ._norm_kwargs ["norm" ] = "ortho"
@@ -167,13 +179,17 @@ def _matvec(self, x: NDArray) -> NDArray:
167179 if not self .clinear :
168180 x = np .real (x )
169181 if self .real :
170- y = scipy .fft .rfft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
182+ y = scipy .fft .rfft (
183+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
184+ )
171185 # Apply scaling to obtain a correct adjoint for this operator
172186 y = np .swapaxes (y , - 1 , self .axis )
173187 y [..., 1 : 1 + (self .nfft - 1 ) // 2 ] *= np .sqrt (2 )
174188 y = np .swapaxes (y , self .axis , - 1 )
175189 else :
176- y = scipy .fft .fft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
190+ y = scipy .fft .fft (
191+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
192+ )
177193 if self .norm is _FFTNorms .ONE_OVER_N :
178194 y *= self ._scale
179195 if self .fftshift_after :
@@ -190,9 +206,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
190206 x = np .swapaxes (x , - 1 , self .axis )
191207 x [..., 1 : 1 + (self .nfft - 1 ) // 2 ] /= np .sqrt (2 )
192208 x = np .swapaxes (x , self .axis , - 1 )
193- y = scipy .fft .irfft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
209+ y = scipy .fft .irfft (
210+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
211+ )
194212 else :
195- y = scipy .fft .ifft (x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs )
213+ y = scipy .fft .ifft (
214+ x , n = self .nfft , axis = self .axis , ** self ._norm_kwargs , ** self ._kwargs_fft
215+ )
196216 if self .norm is _FFTNorms .NONE :
197217 y *= self ._scale
198218
@@ -227,7 +247,7 @@ def __init__(
227247 ifftshift_before : bool = False ,
228248 fftshift_after : bool = False ,
229249 dtype : DTypeLike = "complex128" ,
230- ** kwargs_fftw ,
250+ ** kwargs_fft ,
231251 ) -> None :
232252 if np .dtype (dtype ) == np .float16 :
233253 warnings .warn (
@@ -236,13 +256,13 @@ def __init__(
236256 dtype = np .float32
237257
238258 for badop in ["ortho" , "normalise_idft" ]:
239- if badop in kwargs_fftw :
259+ if badop in kwargs_fft :
240260 if badop == "ortho" and norm == "ortho" :
241261 continue
242262 warnings .warn (
243263 f"FFTW option '{ badop } ' will be overwritten by norm={ norm } "
244264 )
245- del kwargs_fftw [badop ]
265+ del kwargs_fft [badop ]
246266
247267 super ().__init__ (
248268 dims = dims ,
@@ -298,10 +318,10 @@ def __init__(
298318 self ._scale = 1.0 / self .nfft
299319
300320 self .fftplan = pyfftw .FFTW (
301- self .x , self .y , axes = (self .axis ,), direction = "FFTW_FORWARD" , ** kwargs_fftw
321+ self .x , self .y , axes = (self .axis ,), direction = "FFTW_FORWARD" , ** kwargs_fft
302322 )
303323 self .ifftplan = pyfftw .FFTW (
304- self .y , self .x , axes = (self .axis ,), direction = "FFTW_BACKWARD" , ** kwargs_fftw
324+ self .y , self .x , axes = (self .axis ,), direction = "FFTW_BACKWARD" , ** kwargs_fft
305325 )
306326
307327 @reshaped
@@ -386,7 +406,7 @@ def FFT(
386406 engine : str = "numpy" ,
387407 dtype : DTypeLike = "complex128" ,
388408 name : str = "F" ,
389- ** kwargs_fftw ,
409+ ** kwargs_fft ,
390410) -> LinearOperator :
391411 r"""One dimensional Fast-Fourier Transform.
392412
@@ -479,9 +499,9 @@ def FFT(
479499 .. versionadded:: 2.0.0
480500
481501 Name of operator (to be used by :func:`pylops.utils.describe.describe`)
482- **kwargs_fftw
483- Arbitrary keyword arguments
484- for :py:class:`pyfftw.FTTW`
502+ **kwargs_fft
503+ Arbitrary keyword arguments to be passed to the selected fft method
504+
485505
486506 Attributes
487507 ----------
@@ -557,7 +577,7 @@ def FFT(
557577 ifftshift_before = ifftshift_before ,
558578 fftshift_after = fftshift_after ,
559579 dtype = dtype ,
560- ** kwargs_fftw ,
580+ ** kwargs_fft ,
561581 )
562582 elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None ):
563583 if engine == "fftw" and pyfftw_message is not None :
@@ -572,6 +592,7 @@ def FFT(
572592 ifftshift_before = ifftshift_before ,
573593 fftshift_after = fftshift_after ,
574594 dtype = dtype ,
595+ ** kwargs_fft ,
575596 )
576597 elif engine == "scipy" :
577598 f = _FFT_scipy (
@@ -584,6 +605,7 @@ def FFT(
584605 ifftshift_before = ifftshift_before ,
585606 fftshift_after = fftshift_after ,
586607 dtype = dtype ,
608+ ** kwargs_fft ,
587609 )
588610 else :
589611 raise NotImplementedError ("engine must be numpy, fftw or scipy" )
0 commit comments