Skip to content

Commit 2e2fe15

Browse files
committed
feat: dims and dimsd for relevant signalprocessing operators
1 parent 9bd1b1f commit 2e2fe15

File tree

13 files changed

+138
-76
lines changed

13 files changed

+138
-76
lines changed

pylops/signalprocessing/Bilinear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def __init__(self, iava, dims, dtype="float64"):
8484

8585
# define dimension of data
8686
ndims = len(dims)
87-
self.dims = dims
88-
self.dimsd = [len(iava[1])] + list(dims[2:])
87+
self.dims = tuple(dims)
88+
self.dimsd = tuple([len(iava[1])] + list(dims[2:]))
8989

9090
# find indices and weights
9191
self.iava_t = ncp.floor(iava[0]).astype(int)

pylops/signalprocessing/Convolve1D.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import warnings
22

33
import numpy as np
4-
from numpy.core.multiarray import normalize_axis_index
54

65
from pylops import LinearOperator
7-
from pylops.utils._internal import _value_or_list_like_to_array
6+
from pylops.utils._internal import _value_or_list_like_to_tuple
87
from pylops.utils.backend import (
98
get_convolve,
109
get_fftconvolve,
@@ -118,7 +117,7 @@ class Convolve1D(LinearOperator):
118117
"""
119118

120119
def __init__(self, dims, h, offset=0, axis=-1, dtype="float64", method=None):
121-
self.dims = _value_or_list_like_to_array(dims)
120+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
122121
self.axis = axis
123122

124123
if offset > len(h) - 1:
@@ -149,7 +148,7 @@ def __init__(self, dims, h, offset=0, axis=-1, dtype="float64", method=None):
149148

150149
# choose method and function handle
151150
self.convfunc, self.method = _choose_convfunc(h, method, self.dimsorig)
152-
self.shape = (np.prod(self.dims), np.prod(self.dims))
151+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
153152
self.dtype = np.dtype(dtype)
154153
self.explicit = False
155154

pylops/signalprocessing/ConvolveND.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from numpy.core.multiarray import normalize_axis_index
55

66
from pylops import LinearOperator
7-
from pylops.utils._internal import _value_or_list_like_to_array
7+
from pylops.utils._internal import _value_or_list_like_to_tuple
88
from pylops.utils.backend import (
99
get_array_module,
1010
get_convolve,
@@ -64,7 +64,7 @@ def __init__(
6464
dtype="float64",
6565
):
6666
ncp = get_array_module(h)
67-
self.dims = _value_or_list_like_to_array(dims)
67+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
6868
self.axes = (
6969
np.arange(len(self.dims))
7070
if axes is None
@@ -106,7 +106,7 @@ def __init__(
106106
self.correlate = get_correlate(h)
107107
self.method = method
108108

109-
self.shape = (np.prod(self.dims), np.prod(self.dims))
109+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
110110
self.dtype = np.dtype(dtype)
111111
self.explicit = False
112112

pylops/signalprocessing/DWT.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pylops import LinearOperator
88
from pylops.basicoperators import Pad
9+
from pylops.utils._internal import _value_or_list_like_to_tuple
910

1011
try:
1112
import pywt
@@ -104,15 +105,16 @@ def __init__(self, dims, axis=-1, wavelet="haar", level=1, dtype="float64"):
104105
if isinstance(dims, int):
105106
dims = (dims,)
106107

108+
self.dims = _value_or_list_like_to_tuple(dims)
107109
# define padding for length to be power of 2
108-
ndimpow2 = max(2 ** ceil(log(dims[axis], 2)), 2 ** level)
109-
pad = [(0, 0)] * len(dims)
110-
pad[axis] = (0, ndimpow2 - dims[axis])
111-
self.pad = Pad(dims, pad)
112-
self.dims = dims
110+
ndimpow2 = max(2 ** ceil(log(self.dims[axis], 2)), 2 ** level)
111+
pad = [(0, 0)] * len(self.dims)
112+
pad[axis] = (0, ndimpow2 - self.dims[axis])
113+
self.pad = Pad(self.dims, pad)
113114
self.axis = axis
114-
self.dimsd = list(dims)
115-
self.dimsd[self.axis] = ndimpow2
115+
dimsd = list(self.dims)
116+
dimsd[self.axis] = ndimpow2
117+
self.dimsd = tuple(dimsd)
116118

117119
# apply transform to find out slices
118120
_, self.sl = pywt.coeffs_to_array(
@@ -130,7 +132,8 @@ def __init__(self, dims, axis=-1, wavelet="haar", level=1, dtype="float64"):
130132
self.waveletadj = _adjointwavelet(wavelet)
131133
self.level = level
132134
self.reshape = True if len(self.dims) > 1 else False
133-
self.shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
135+
136+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
134137
self.dtype = np.dtype(dtype)
135138
self.explicit = False
136139

pylops/signalprocessing/DWT2D.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,18 @@ def __init__(self, dims, axes=(-2, -1), wavelet="haar", level=1, dtype="float64"
8080
raise ModuleNotFoundError(pywt_message)
8181
_checkwavelet(wavelet)
8282

83+
self.dims = tuple(dims)
8384
# define padding for length to be power of 2
84-
ndimpow2 = [max(2 ** ceil(log(dims[ax], 2)), 2 ** level) for ax in axes]
85-
pad = [(0, 0)] * len(dims)
85+
ndimpow2 = [max(2 ** ceil(log(self.dims[ax], 2)), 2 ** level) for ax in axes]
86+
pad = [(0, 0)] * len(self.dims)
8687
for i, ax in enumerate(axes):
87-
pad[ax] = (0, ndimpow2[i] - dims[ax])
88-
self.pad = Pad(dims, pad)
89-
self.dims = dims
88+
pad[ax] = (0, ndimpow2[i] - self.dims[ax])
89+
self.pad = Pad(self.dims, pad)
9090
self.axes = axes
91-
self.dimsd = list(dims)
91+
dimsd = list(self.dims)
9292
for i, ax in enumerate(axes):
93-
self.dimsd[ax] = ndimpow2[i]
93+
dimsd[ax] = ndimpow2[i]
94+
self.dimsd = tuple(dimsd)
9495

9596
# apply transform once again to find out slices
9697
_, self.sl = pywt.coeffs_to_array(
@@ -106,7 +107,8 @@ def __init__(self, dims, axes=(-2, -1), wavelet="haar", level=1, dtype="float64"
106107
self.wavelet = wavelet
107108
self.waveletadj = _adjointwavelet(wavelet)
108109
self.level = level
109-
self.shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
110+
111+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
110112
self.dtype = np.dtype(dtype)
111113
self.explicit = False
112114

pylops/signalprocessing/FFT.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _matvec(self, x):
8484
return y
8585

8686
def _rmatvec(self, x):
87-
x = np.reshape(x, self.dims_fft)
87+
x = np.reshape(x, self.dimsd)
8888
if self.fftshift_after:
8989
x = np.fft.ifftshift(x, axes=self.axis)
9090
if self.real:
@@ -175,7 +175,7 @@ def _matvec(self, x):
175175
return y
176176

177177
def _rmatvec(self, x):
178-
x = np.reshape(x, self.dims_fft)
178+
x = np.reshape(x, self.dimsd)
179179
if self.fftshift_after:
180180
x = scipy.fft.ifftshift(x, axes=self.axis)
181181
if self.real:
@@ -255,29 +255,30 @@ def __init__(
255255
f"fftw backend returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}."
256256
)
257257

258-
self.dims_t = self.dims.copy()
259-
self.dims_t[self.axis] = self.nfft
258+
dims_t = list(self.dims)
259+
dims_t[self.axis] = self.nfft
260+
self.dims_t = dims_t
260261

261262
# define padding(fftw requires the user to provide padded input signal)
262263
self.pad = np.zeros((self.ndim, 2), dtype=int)
263264
if self.real:
264265
if self.nfft % 2:
265266
self.pad[self.axis, 1] = (
266-
2 * (self.dims_fft[self.axis] - 1) + 1 - self.dims[self.axis]
267+
2 * (self.dimsd[self.axis] - 1) + 1 - self.dims[self.axis]
267268
)
268269
else:
269270
self.pad[self.axis, 1] = (
270-
2 * (self.dims_fft[self.axis] - 1) - self.dims[self.axis]
271+
2 * (self.dimsd[self.axis] - 1) - self.dims[self.axis]
271272
)
272273
else:
273-
self.pad[self.axis, 1] = self.dims_fft[self.axis] - self.dims[self.axis]
274+
self.pad[self.axis, 1] = self.dimsd[self.axis] - self.dims[self.axis]
274275
self.dopad = True if np.sum(self.pad) > 0 else False
275276

276277
# create empty arrays and plans for fft/ifft
277278
self.x = pyfftw.empty_aligned(
278279
self.dims_t, dtype=self.rdtype if real else self.cdtype
279280
)
280-
self.y = pyfftw.empty_aligned(self.dims_fft, dtype=self.cdtype)
281+
self.y = pyfftw.empty_aligned(self.dimsd, dtype=self.cdtype)
281282

282283
# Use FFTW without norm-related keywords above. In this case, FFTW standard
283284
# behavior is to scale with 1/N on the inverse transform. The _scale below
@@ -327,7 +328,7 @@ def _matvec(self, x):
327328
return y.ravel()
328329

329330
def _rmatvec(self, x):
330-
x = np.reshape(x, self.dims_fft)
331+
x = np.reshape(x, self.dimsd)
331332
if self.fftshift_after:
332333
x = np.fft.ifftshift(x, axes=self.axis)
333334

@@ -479,10 +480,14 @@ def FFT(
479480
480481
Attributes
481482
----------
482-
dims_fft : :obj:`tuple`
483+
dimsd : :obj:`tuple`
483484
Shape of the array after the forward, but before linearization.
484485
485-
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dims_fft)``.
486+
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
487+
dims_fft : :obj:`tuple`
488+
489+
.. deprecated:: 2.0.0
490+
Use ``dimsd`` instead.
486491
f : :obj:`numpy.ndarray`
487492
Discrete Fourier Transform sample frequencies
488493
real : :obj:`bool`

pylops/signalprocessing/FFT2D.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _matvec(self, x):
7979
return y.ravel()
8080

8181
def _rmatvec(self, x):
82-
x = np.reshape(x, self.dims_fft)
82+
x = np.reshape(x, self.dimsd)
8383
if self.fftshift_after.any():
8484
x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
8585
if self.real:
@@ -177,7 +177,7 @@ def _matvec(self, x):
177177
return y.ravel()
178178

179179
def _rmatvec(self, x):
180-
x = np.reshape(x, self.dims_fft)
180+
x = np.reshape(x, self.dimsd)
181181
if self.fftshift_after.any():
182182
x = scipy.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
183183
if self.real:
@@ -312,10 +312,14 @@ def FFT2D(
312312
313313
Attributes
314314
----------
315-
dims_fft : :obj:`tuple`
315+
dimsd : :obj:`tuple`
316316
Shape of the array after the forward, but before linearization.
317317
318-
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dims_fft)``.
318+
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
319+
dims_fft : :obj:`tuple`
320+
321+
.. deprecated:: 2.0.0
322+
Use ``dimsd`` instead.
319323
f1 : :obj:`numpy.ndarray`
320324
Discrete Fourier Transform sample frequencies along ``axes[0]``
321325
f2 : :obj:`numpy.ndarray`

pylops/signalprocessing/FFTND.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _matvec(self, x):
7070
return y.ravel()
7171

7272
def _rmatvec(self, x):
73-
x = np.reshape(x, self.dims_fft)
73+
x = np.reshape(x, self.dimsd)
7474
if self.fftshift_after.any():
7575
x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
7676
if self.real:
@@ -158,7 +158,7 @@ def _matvec(self, x):
158158
return y.ravel()
159159

160160
def _rmatvec(self, x):
161-
x = np.reshape(x, self.dims_fft)
161+
x = np.reshape(x, self.dimsd)
162162
if self.fftshift_after.any():
163163
x = scipy.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
164164
if self.real:
@@ -301,10 +301,14 @@ def FFTND(
301301
302302
Attributes
303303
----------
304-
dims_fft : :obj:`tuple`
304+
dimsd : :obj:`tuple`
305305
Shape of the array after the forward, but before linearization.
306306
307-
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dims_fft)``.
307+
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
308+
dims_fft : :obj:`tuple`
309+
310+
.. deprecated:: 2.0.0
311+
Use ``dimsd`` instead.
308312
fs : :obj:`tuple`
309313
Each element of the tuple corresponds to the Discrete Fourier Transform
310314
sample frequencies along the respective direction given by ``axes``.

pylops/signalprocessing/Interp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pylops import LinearOperator
77
from pylops.basicoperators import Diagonal, MatrixMult, Restriction, Transpose
8-
from pylops.utils._internal import _value_or_list_like_to_array
8+
from pylops.utils._internal import _value_or_list_like_to_tuple
99
from pylops.utils.backend import get_array_module
1010

1111
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
@@ -188,7 +188,7 @@ def Interp(dims, iava, axis=-1, kind="linear", dtype="float64"):
188188
:math:`i,j` possible combinations.
189189
190190
"""
191-
dims = _value_or_list_like_to_array(dims)
191+
dims = _value_or_list_like_to_tuple(dims)
192192

193193
if kind == "nearest":
194194
interpop, iava = _nearestinterp(dims, iava, axis=axis, dtype=dtype)
@@ -198,4 +198,4 @@ def Interp(dims, iava, axis=-1, kind="linear", dtype="float64"):
198198
interpop = _sincinterp(dims, iava, axis=axis, dtype=dtype)
199199
else:
200200
raise NotImplementedError("kind is not correct...")
201-
return LinearOperator(interpop), iava
201+
return interpop, iava

pylops/signalprocessing/Seislet.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -386,17 +386,18 @@ def __init__(
386386
raise NotImplementedError("kind should be haar or linear")
387387

388388
# define padding for length to be power of 2
389-
dims = slopes.shape
390-
ndimpow2 = 2 ** ceil(log(dims[0], 2))
391-
pad = [(0, 0)] * len(dims)
392-
pad[0] = (0, ndimpow2 - dims[0])
393-
self.pad = Pad(dims, pad)
394-
self.dims = list(dims)
395-
self.dims[0] = ndimpow2
396-
self.nx, self.nt = self.dims
389+
self.dims = slopes.shape
390+
dimsd = list(self.dims)
391+
ndimpow2 = 2 ** ceil(log(dimsd[0], 2))
392+
pad = [(0, 0)] * len(dimsd)
393+
pad[0] = (0, ndimpow2 - dimsd[0])
394+
self.pad = Pad(dimsd, pad)
395+
dimsd[0] = ndimpow2
396+
self.dimsd = dimsd
397+
self.nx, self.nt = self.dimsd
397398

398399
# define levels
399-
nlevels_max = int(np.log2(self.dims[0]))
400+
nlevels_max = int(np.log2(self.dimsd[0]))
400401
self.levels_size = np.flip(np.array([2 ** i for i in range(nlevels_max)]))
401402
if level is not None:
402403
self.levels_size = self.levels_size[:level]
@@ -408,15 +409,15 @@ def __init__(
408409
self.levels_cum = np.insert(self.levels_cum, 0, 0)
409410

410411
self.dx, self.dt = sampling
411-
self.slopes = (self.pad * slopes.ravel()).reshape(self.dims)
412+
self.slopes = (self.pad * slopes.ravel()).reshape(self.dimsd)
412413
self.inv = inv
413-
self.shape = (int(np.prod(self.slopes.size)), int(np.prod(slopes.size)))
414+
self.shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
414415
self.dtype = np.dtype(dtype)
415416
self.explicit = False
416417

417418
def _matvec(self, x):
418419
x = self.pad.matvec(x)
419-
x = np.reshape(x, self.dims)
420+
x = np.reshape(x, self.dimsd)
420421
y = np.zeros((np.sum(self.levels_size) + self.levels_size[-1], self.nt))
421422
for ilevel in range(self.level):
422423
odd = x[1::2]
@@ -437,7 +438,7 @@ def _matvec(self, x):
437438

438439
def _rmatvec(self, x):
439440
if not self.inv:
440-
x = np.reshape(x, self.dims)
441+
x = np.reshape(x, self.dimsd)
441442
y = x[self.levels_cum[-1] :]
442443
for ilevel in range(self.level, 0, -1):
443444
res = x[self.levels_cum[ilevel - 1] : self.levels_cum[ilevel]]
@@ -472,7 +473,7 @@ def _rmatvec(self, x):
472473
return y
473474

474475
def inverse(self, x):
475-
x = np.reshape(x, self.dims)
476+
x = np.reshape(x, self.dimsd)
476477
y = x[self.levels_cum[-1] :]
477478
for ilevel in range(self.level, 0, -1):
478479
res = x[self.levels_cum[ilevel - 1] : self.levels_cum[ilevel]]

0 commit comments

Comments
 (0)