Skip to content

Commit 6234478

Browse files
committed
feat: added kwargs_fft to all fft operators
1 parent 042ea41 commit 6234478

File tree

4 files changed

+240
-37
lines changed

4 files changed

+240
-37
lines changed

pylops/signalprocessing/fft.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
ifftshift_before: bool = False,
3838
fftshift_after: bool = False,
3939
dtype: DTypeLike = "complex128",
40+
**kwargs_fft,
4041
) -> None:
4142
super().__init__(
4243
dims=dims,
@@ -54,6 +55,7 @@ def __init__(
5455
f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be casted to {self.cdtype}."
5556
)
5657

58+
self._kwargs_fft = kwargs_fft
5759
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
5860
if self.norm is _FFTNorms.ORTHO:
5961
self._norm_kwargs["norm"] = "ortho"
@@ -74,14 +76,18 @@ def _matvec(self, x: NDArray) -> NDArray:
7476
if not self.clinear:
7577
x = ncp.real(x)
7678
if self.real:
77-
y = ncp.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
79+
y = ncp.fft.rfft(
80+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
81+
)
7882
# Apply scaling to obtain a correct adjoint for this operator
7983
y = ncp.swapaxes(y, -1, self.axis)
8084
# y[..., 1 : 1 + (self.nfft - 1) // 2] *= ncp.sqrt(2)
8185
y = inplace_multiply(ncp.sqrt(2), y, self.slice)
8286
y = ncp.swapaxes(y, self.axis, -1)
8387
else:
84-
y = ncp.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
88+
y = ncp.fft.fft(
89+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
90+
)
8591
if self.norm is _FFTNorms.ONE_OVER_N:
8692
y *= self._scale
8793
if self.fftshift_after:
@@ -101,9 +107,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
101107
# x[..., 1 : 1 + (self.nfft - 1) // 2] /= ncp.sqrt(2)
102108
x = inplace_divide(ncp.sqrt(2), x, self.slice)
103109
x = ncp.swapaxes(x, self.axis, -1)
104-
y = ncp.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
110+
y = ncp.fft.irfft(
111+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
112+
)
105113
else:
106-
y = ncp.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
114+
y = ncp.fft.ifft(
115+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
116+
)
107117
if self.norm is _FFTNorms.NONE:
108118
y *= self._scale
109119

@@ -139,6 +149,7 @@ def __init__(
139149
ifftshift_before: bool = False,
140150
fftshift_after: bool = False,
141151
dtype: DTypeLike = "complex128",
152+
**kwargs_fft,
142153
) -> None:
143154
super().__init__(
144155
dims=dims,
@@ -152,6 +163,7 @@ def __init__(
152163
dtype=dtype,
153164
)
154165

166+
self._kwargs_fft = kwargs_fft
155167
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
156168
if self.norm is _FFTNorms.ORTHO:
157169
self._norm_kwargs["norm"] = "ortho"
@@ -167,13 +179,17 @@ def _matvec(self, x: NDArray) -> NDArray:
167179
if not self.clinear:
168180
x = np.real(x)
169181
if self.real:
170-
y = scipy.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
182+
y = scipy.fft.rfft(
183+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
184+
)
171185
# Apply scaling to obtain a correct adjoint for this operator
172186
y = np.swapaxes(y, -1, self.axis)
173187
y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2)
174188
y = np.swapaxes(y, self.axis, -1)
175189
else:
176-
y = scipy.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
190+
y = scipy.fft.fft(
191+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
192+
)
177193
if self.norm is _FFTNorms.ONE_OVER_N:
178194
y *= self._scale
179195
if self.fftshift_after:
@@ -190,9 +206,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
190206
x = np.swapaxes(x, -1, self.axis)
191207
x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2)
192208
x = np.swapaxes(x, self.axis, -1)
193-
y = scipy.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
209+
y = scipy.fft.irfft(
210+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
211+
)
194212
else:
195-
y = scipy.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
213+
y = scipy.fft.ifft(
214+
x, n=self.nfft, axis=self.axis, **self._norm_kwargs, **self._kwargs_fft
215+
)
196216
if self.norm is _FFTNorms.NONE:
197217
y *= self._scale
198218

@@ -227,7 +247,7 @@ def __init__(
227247
ifftshift_before: bool = False,
228248
fftshift_after: bool = False,
229249
dtype: DTypeLike = "complex128",
230-
**kwargs_fftw,
250+
**kwargs_fft,
231251
) -> None:
232252
if np.dtype(dtype) == np.float16:
233253
warnings.warn(
@@ -236,13 +256,13 @@ def __init__(
236256
dtype = np.float32
237257

238258
for badop in ["ortho", "normalise_idft"]:
239-
if badop in kwargs_fftw:
259+
if badop in kwargs_fft:
240260
if badop == "ortho" and norm == "ortho":
241261
continue
242262
warnings.warn(
243263
f"FFTW option '{badop}' will be overwritten by norm={norm}"
244264
)
245-
del kwargs_fftw[badop]
265+
del kwargs_fft[badop]
246266

247267
super().__init__(
248268
dims=dims,
@@ -298,10 +318,10 @@ def __init__(
298318
self._scale = 1.0 / self.nfft
299319

300320
self.fftplan = pyfftw.FFTW(
301-
self.x, self.y, axes=(self.axis,), direction="FFTW_FORWARD", **kwargs_fftw
321+
self.x, self.y, axes=(self.axis,), direction="FFTW_FORWARD", **kwargs_fft
302322
)
303323
self.ifftplan = pyfftw.FFTW(
304-
self.y, self.x, axes=(self.axis,), direction="FFTW_BACKWARD", **kwargs_fftw
324+
self.y, self.x, axes=(self.axis,), direction="FFTW_BACKWARD", **kwargs_fft
305325
)
306326

307327
@reshaped
@@ -386,7 +406,7 @@ def FFT(
386406
engine: str = "numpy",
387407
dtype: DTypeLike = "complex128",
388408
name: str = "F",
389-
**kwargs_fftw,
409+
**kwargs_fft,
390410
) -> LinearOperator:
391411
r"""One dimensional Fast-Fourier Transform.
392412
@@ -479,9 +499,9 @@ def FFT(
479499
.. versionadded:: 2.0.0
480500
481501
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
482-
**kwargs_fftw
483-
Arbitrary keyword arguments
484-
for :py:class:`pyfftw.FTTW`
502+
**kwargs_fft
503+
Arbitrary keyword arguments to be passed to the selected fft method
504+
485505
486506
Attributes
487507
----------
@@ -557,7 +577,7 @@ def FFT(
557577
ifftshift_before=ifftshift_before,
558578
fftshift_after=fftshift_after,
559579
dtype=dtype,
560-
**kwargs_fftw,
580+
**kwargs_fft,
561581
)
562582
elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None):
563583
if engine == "fftw" and pyfftw_message is not None:
@@ -572,6 +592,7 @@ def FFT(
572592
ifftshift_before=ifftshift_before,
573593
fftshift_after=fftshift_after,
574594
dtype=dtype,
595+
**kwargs_fft,
575596
)
576597
elif engine == "scipy":
577598
f = _FFT_scipy(
@@ -584,6 +605,7 @@ def FFT(
584605
ifftshift_before=ifftshift_before,
585606
fftshift_after=fftshift_after,
586607
dtype=dtype,
608+
**kwargs_fft,
587609
)
588610
else:
589611
raise NotImplementedError("engine must be numpy, fftw or scipy")

pylops/signalprocessing/fft2d.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
ifftshift_before: bool = False,
3131
fftshift_after: bool = False,
3232
dtype: DTypeLike = "complex128",
33+
**kwargs_fft,
3334
) -> None:
3435
super().__init__(
3536
dims=dims,
@@ -56,6 +57,7 @@ def __init__(
5657
self.f1, self.f2 = self.fs
5758
del self.fs
5859

60+
self._kwargs_fft = kwargs_fft
5961
self._norm_kwargs: Dict[str, Union[None, str]] = {
6062
"norm": None
6163
} # equivalent to "backward" in Numpy/Scipy
@@ -74,13 +76,17 @@ def _matvec(self, x):
7476
if not self.clinear:
7577
x = ncp.real(x)
7678
if self.real:
77-
y = ncp.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
79+
y = ncp.fft.rfft2(
80+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
81+
)
7882
# Apply scaling to obtain a correct adjoint for this operator
7983
y = ncp.swapaxes(y, -1, self.axes[-1])
8084
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
8185
y = ncp.swapaxes(y, self.axes[-1], -1)
8286
else:
83-
y = ncp.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
87+
y = ncp.fft.fft2(
88+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
89+
)
8490
if self.norm is _FFTNorms.ONE_OVER_N:
8591
y *= self._scale
8692
y = y.astype(self.cdtype)
@@ -99,9 +105,13 @@ def _rmatvec(self, x):
99105
x = ncp.swapaxes(x, -1, self.axes[-1])
100106
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
101107
x = ncp.swapaxes(x, self.axes[-1], -1)
102-
y = ncp.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
108+
y = ncp.fft.irfft2(
109+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
110+
)
103111
else:
104-
y = ncp.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
112+
y = ncp.fft.ifft2(
113+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
114+
)
105115
if self.norm is _FFTNorms.NONE:
106116
y *= self._scale
107117
if self.nffts[0] > self.dims[self.axes[0]]:
@@ -137,6 +147,7 @@ def __init__(
137147
ifftshift_before: bool = False,
138148
fftshift_after: bool = False,
139149
dtype: DTypeLike = "complex128",
150+
**kwargs_fft,
140151
) -> None:
141152
super().__init__(
142153
dims=dims,
@@ -159,6 +170,7 @@ def __init__(
159170
self.f1, self.f2 = self.fs
160171
del self.fs
161172

173+
self._kwargs_fft = kwargs_fft
162174
self._norm_kwargs: Dict[str, Union[None, str]] = {
163175
"norm": None
164176
} # equivalent to "backward" in Numpy/Scipy
@@ -176,13 +188,17 @@ def _matvec(self, x):
176188
if not self.clinear:
177189
x = np.real(x)
178190
if self.real:
179-
y = scipy.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
191+
y = scipy.fft.rfft2(
192+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
193+
)
180194
# Apply scaling to obtain a correct adjoint for this operator
181195
y = np.swapaxes(y, -1, self.axes[-1])
182196
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
183197
y = np.swapaxes(y, self.axes[-1], -1)
184198
else:
185-
y = scipy.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
199+
y = scipy.fft.fft2(
200+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
201+
)
186202
if self.norm is _FFTNorms.ONE_OVER_N:
187203
y *= self._scale
188204
if self.fftshift_after.any():
@@ -199,9 +215,13 @@ def _rmatvec(self, x):
199215
x = np.swapaxes(x, -1, self.axes[-1])
200216
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
201217
x = np.swapaxes(x, self.axes[-1], -1)
202-
y = scipy.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
218+
y = scipy.fft.irfft2(
219+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
220+
)
203221
else:
204-
y = scipy.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
222+
y = scipy.fft.ifft2(
223+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
224+
)
205225
if self.norm is _FFTNorms.NONE:
206226
y *= self._scale
207227
y = np.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0])
@@ -230,6 +250,7 @@ def FFT2D(
230250
engine: str = "numpy",
231251
dtype: DTypeLike = "complex128",
232252
name: str = "F",
253+
**kwargs_fft,
233254
) -> LinearOperator:
234255
r"""Two dimensional Fast-Fourier Transform.
235256
@@ -328,6 +349,10 @@ def FFT2D(
328349
.. versionadded:: 2.0.0
329350
330351
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
352+
**kwargs_fft
353+
.. versionadded:: 2.5.0
354+
355+
Arbitrary keyword arguments to be passed to the selected fft method
331356
332357
Attributes
333358
----------
@@ -409,6 +434,7 @@ def FFT2D(
409434
ifftshift_before=ifftshift_before,
410435
fftshift_after=fftshift_after,
411436
dtype=dtype,
437+
**kwargs_fft,
412438
)
413439
elif engine == "scipy":
414440
f = _FFT2D_scipy(
@@ -421,6 +447,7 @@ def FFT2D(
421447
ifftshift_before=ifftshift_before,
422448
fftshift_after=fftshift_after,
423449
dtype=dtype,
450+
**kwargs_fft,
424451
)
425452
else:
426453
raise NotImplementedError("engine must be numpy or scipy")

0 commit comments

Comments
 (0)