@@ -29,6 +29,7 @@ def __init__(
2929 ifftshift_before : bool = False ,
3030 fftshift_after : bool = False ,
3131 dtype : DTypeLike = "complex128" ,
32+ ** kwargs_fft ,
3233 ) -> None :
3334 super ().__init__ (
3435 dims = dims ,
@@ -45,7 +46,7 @@ def __init__(
4546 warnings .warn (
4647 f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to { self .cdtype } ."
4748 )
48-
49+ self . _kwargs_fft = kwargs_fft
4950 self ._norm_kwargs = {"norm" : None } # equivalent to "backward" in Numpy/Scipy
5051 if self .norm is _FFTNorms .ORTHO :
5152 self ._norm_kwargs ["norm" ] = "ortho"
@@ -61,13 +62,17 @@ def _matvec(self, x: NDArray) -> NDArray:
6162 if not self .clinear :
6263 x = np .real (x )
6364 if self .real :
64- y = np .fft .rfftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
65+ y = np .fft .rfftn (
66+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
67+ )
6568 # Apply scaling to obtain a correct adjoint for this operator
6669 y = np .swapaxes (y , - 1 , self .axes [- 1 ])
6770 y [..., 1 : 1 + (self .nffts [- 1 ] - 1 ) // 2 ] *= np .sqrt (2 )
6871 y = np .swapaxes (y , self .axes [- 1 ], - 1 )
6972 else :
70- y = np .fft .fftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
73+ y = np .fft .fftn (
74+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
75+ )
7176 if self .norm is _FFTNorms .ONE_OVER_N :
7277 y *= self ._scale
7378 y = y .astype (self .cdtype )
@@ -85,9 +90,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
8590 x = np .swapaxes (x , - 1 , self .axes [- 1 ])
8691 x [..., 1 : 1 + (self .nffts [- 1 ] - 1 ) // 2 ] /= np .sqrt (2 )
8792 x = np .swapaxes (x , self .axes [- 1 ], - 1 )
88- y = np .fft .irfftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
93+ y = np .fft .irfftn (
94+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
95+ )
8996 else :
90- y = np .fft .ifftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
97+ y = np .fft .ifftn (
98+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
99+ )
91100 if self .norm is _FFTNorms .NONE :
92101 y *= self ._scale
93102 for ax , nfft in zip (self .axes , self .nffts ):
@@ -122,6 +131,7 @@ def __init__(
122131 ifftshift_before : bool = False ,
123132 fftshift_after : bool = False ,
124133 dtype : DTypeLike = "complex128" ,
134+ ** kwargs_fft ,
125135 ) -> None :
126136 super ().__init__ (
127137 dims = dims ,
@@ -134,7 +144,7 @@ def __init__(
134144 fftshift_after = fftshift_after ,
135145 dtype = dtype ,
136146 )
137-
147+ self . _kwargs_fft = kwargs_fft
138148 self ._norm_kwargs = {"norm" : None } # equivalent to "backward" in Numpy/Scipy
139149 if self .norm is _FFTNorms .ORTHO :
140150 self ._norm_kwargs ["norm" ] = "ortho"
@@ -151,13 +161,17 @@ def _matvec(self, x: NDArray) -> NDArray:
151161 if not self .clinear :
152162 x = np .real (x )
153163 if self .real :
154- y = sp_fft .rfftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
164+ y = sp_fft .rfftn (
165+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
166+ )
155167 # Apply scaling to obtain a correct adjoint for this operator
156168 y = np .swapaxes (y , - 1 , self .axes [- 1 ])
157169 y [..., 1 : 1 + (self .nffts [- 1 ] - 1 ) // 2 ] *= np .sqrt (2 )
158170 y = np .swapaxes (y , self .axes [- 1 ], - 1 )
159171 else :
160- y = sp_fft .fftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
172+ y = sp_fft .fftn (
173+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
174+ )
161175 if self .norm is _FFTNorms .ONE_OVER_N :
162176 y *= self ._scale
163177 if self .fftshift_after .any ():
@@ -175,9 +189,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
175189 x = np .swapaxes (x , - 1 , self .axes [- 1 ])
176190 x [..., 1 : 1 + (self .nffts [- 1 ] - 1 ) // 2 ] /= np .sqrt (2 )
177191 x = np .swapaxes (x , self .axes [- 1 ], - 1 )
178- y = sp_fft .irfftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
192+ y = sp_fft .irfftn (
193+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
194+ )
179195 else :
180- y = sp_fft .ifftn (x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs )
196+ y = sp_fft .ifftn (
197+ x , s = self .nffts , axes = self .axes , ** self ._norm_kwargs , ** self ._kwargs_fft
198+ )
181199 if self .norm is _FFTNorms .NONE :
182200 y *= self ._scale
183201 for ax , nfft in zip (self .axes , self .nffts ):
@@ -209,6 +227,7 @@ def FFTND(
209227 engine : str = "scipy" ,
210228 dtype : DTypeLike = "complex128" ,
211229 name : str = "F" ,
230+ ** kwargs_fft ,
212231):
213232 r"""N-dimensional Fast-Fourier Transform.
214233
@@ -311,6 +330,8 @@ def FFTND(
311330 .. versionadded:: 2.0.0
312331
313332 Name of operator (to be used by :func:`pylops.utils.describe.describe`)
333+ **kwargs_fft
334+ Arbitrary keyword arguments to be passed to the selected fft method
314335
315336 Attributes
316337 ----------
@@ -396,6 +417,7 @@ def FFTND(
396417 ifftshift_before = ifftshift_before ,
397418 fftshift_after = fftshift_after ,
398419 dtype = dtype ,
420+ ** kwargs_fft ,
399421 )
400422 elif engine == "scipy" :
401423 f = _FFTND_scipy (
@@ -408,6 +430,7 @@ def FFTND(
408430 ifftshift_before = ifftshift_before ,
409431 fftshift_after = fftshift_after ,
410432 dtype = dtype ,
433+ ** kwargs_fft ,
411434 )
412435 else :
413436 raise NotImplementedError ("engine must be numpy or scipy" )
0 commit comments