Skip to content

Commit 42c6b2f

Browse files
authored
feature: added kwargs to FFTND (#577)
1 parent 74e8c68 commit 42c6b2f

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
build:
77
strategy:
88
matrix:
9-
platform: [ ubuntu-latest, macos-latest ]
9+
platform: [ ubuntu-latest, macos-13 ]
1010
python-version: ["3.8", "3.9", "3.10", "3.11"]
1111

1212
runs-on: ${{ matrix.platform }}

pylops/signalprocessing/fftnd.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
ifftshift_before: bool = False,
3030
fftshift_after: bool = False,
3131
dtype: DTypeLike = "complex128",
32+
**kwargs_fft,
3233
) -> None:
3334
super().__init__(
3435
dims=dims,
@@ -45,7 +46,7 @@ def __init__(
4546
warnings.warn(
4647
f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}."
4748
)
48-
49+
self._kwargs_fft = kwargs_fft
4950
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
5051
if self.norm is _FFTNorms.ORTHO:
5152
self._norm_kwargs["norm"] = "ortho"
@@ -61,13 +62,17 @@ def _matvec(self, x: NDArray) -> NDArray:
6162
if not self.clinear:
6263
x = np.real(x)
6364
if self.real:
64-
y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
65+
y = np.fft.rfftn(
66+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
67+
)
6568
# Apply scaling to obtain a correct adjoint for this operator
6669
y = np.swapaxes(y, -1, self.axes[-1])
6770
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
6871
y = np.swapaxes(y, self.axes[-1], -1)
6972
else:
70-
y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
73+
y = np.fft.fftn(
74+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
75+
)
7176
if self.norm is _FFTNorms.ONE_OVER_N:
7277
y *= self._scale
7378
y = y.astype(self.cdtype)
@@ -85,9 +90,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
8590
x = np.swapaxes(x, -1, self.axes[-1])
8691
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
8792
x = np.swapaxes(x, self.axes[-1], -1)
88-
y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
93+
y = np.fft.irfftn(
94+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
95+
)
8996
else:
90-
y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
97+
y = np.fft.ifftn(
98+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
99+
)
91100
if self.norm is _FFTNorms.NONE:
92101
y *= self._scale
93102
for ax, nfft in zip(self.axes, self.nffts):
@@ -122,6 +131,7 @@ def __init__(
122131
ifftshift_before: bool = False,
123132
fftshift_after: bool = False,
124133
dtype: DTypeLike = "complex128",
134+
**kwargs_fft,
125135
) -> None:
126136
super().__init__(
127137
dims=dims,
@@ -134,7 +144,7 @@ def __init__(
134144
fftshift_after=fftshift_after,
135145
dtype=dtype,
136146
)
137-
147+
self._kwargs_fft = kwargs_fft
138148
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
139149
if self.norm is _FFTNorms.ORTHO:
140150
self._norm_kwargs["norm"] = "ortho"
@@ -151,13 +161,17 @@ def _matvec(self, x: NDArray) -> NDArray:
151161
if not self.clinear:
152162
x = np.real(x)
153163
if self.real:
154-
y = sp_fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
164+
y = sp_fft.rfftn(
165+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
166+
)
155167
# Apply scaling to obtain a correct adjoint for this operator
156168
y = np.swapaxes(y, -1, self.axes[-1])
157169
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
158170
y = np.swapaxes(y, self.axes[-1], -1)
159171
else:
160-
y = sp_fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
172+
y = sp_fft.fftn(
173+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
174+
)
161175
if self.norm is _FFTNorms.ONE_OVER_N:
162176
y *= self._scale
163177
if self.fftshift_after.any():
@@ -175,9 +189,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
175189
x = np.swapaxes(x, -1, self.axes[-1])
176190
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
177191
x = np.swapaxes(x, self.axes[-1], -1)
178-
y = sp_fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
192+
y = sp_fft.irfftn(
193+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
194+
)
179195
else:
180-
y = sp_fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
196+
y = sp_fft.ifftn(
197+
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
198+
)
181199
if self.norm is _FFTNorms.NONE:
182200
y *= self._scale
183201
for ax, nfft in zip(self.axes, self.nffts):
@@ -209,6 +227,7 @@ def FFTND(
209227
engine: str = "scipy",
210228
dtype: DTypeLike = "complex128",
211229
name: str = "F",
230+
**kwargs_fft,
212231
):
213232
r"""N-dimensional Fast-Fourier Transform.
214233
@@ -311,6 +330,8 @@ def FFTND(
311330
.. versionadded:: 2.0.0
312331
313332
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
333+
**kwargs_fft
334+
Arbitrary keyword arguments to be passed to the selected fft method
314335
315336
Attributes
316337
----------
@@ -396,6 +417,7 @@ def FFTND(
396417
ifftshift_before=ifftshift_before,
397418
fftshift_after=fftshift_after,
398419
dtype=dtype,
420+
**kwargs_fft,
399421
)
400422
elif engine == "scipy":
401423
f = _FFTND_scipy(
@@ -408,6 +430,7 @@ def FFTND(
408430
ifftshift_before=ifftshift_before,
409431
fftshift_after=fftshift_after,
410432
dtype=dtype,
433+
**kwargs_fft,
411434
)
412435
else:
413436
raise NotImplementedError("engine must be numpy or scipy")

0 commit comments

Comments
 (0)