Skip to content

Commit 9bd1b1f

Browse files
committed
feat: dims and dimsd for relevant basicoperators
1 parent 10fef7b commit 9bd1b1f

File tree

17 files changed

+114
-80
lines changed

17 files changed

+114
-80
lines changed

pylops/basicoperators/CausalIntegration.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from pylops import LinearOperator
6-
from pylops.utils._internal import _value_or_list_like_to_array
6+
from pylops.utils._internal import _value_or_list_like_to_tuple
77

88

99
class CausalIntegration(LinearOperator):
@@ -97,16 +97,18 @@ def __init__(
9797
kind="full",
9898
removefirst=False,
9999
):
100-
self.dims = _value_or_list_like_to_array(dims)
100+
self.dims = _value_or_list_like_to_tuple(dims)
101101
self.axis = axis
102102
self.sampling = sampling
103103
self.kind = kind
104104
if kind == "full" and halfcurrent: # ensure backcompatibility
105105
self.kind = "half"
106106
self.removefirst = removefirst
107-
self.dimsd = self.dims.copy()
107+
dimsd = list(self.dims)
108108
if self.removefirst:
109-
self.dimsd[self.axis] -= 1
109+
dimsd[self.axis] -= 1
110+
self.dimsd = tuple(dimsd)
111+
110112
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
111113
self.dtype = np.dtype(dtype)
112114
self.explicit = False

pylops/basicoperators/Conj.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pylops import LinearOperator
4+
from pylops.utils._internal import _value_or_list_like_to_tuple
45
from pylops.utils.backend import get_array_module
56

67

@@ -41,7 +42,9 @@ class Conj(LinearOperator):
4142
"""
4243

4344
def __init__(self, dims, dtype="complex128"):
44-
self.shape = (np.prod(np.array(dims)), np.prod(np.array(dims)))
45+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
46+
47+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
4548
self.dtype = np.dtype(dtype)
4649
self.explicit = False
4750
self.clinear = False

pylops/basicoperators/Diagonal.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
from pylops import LinearOperator
6+
from pylops.utils._internal import _value_or_list_like_to_tuple
67
from pylops.utils.backend import get_array_module, to_cupy_conditional
78

89

@@ -64,28 +65,26 @@ def __init__(self, diag, dims=None, axis=-1, dtype="float64"):
6465
self.diag = diag.ravel()
6566
self.complex = True if ncp.iscomplexobj(self.diag) else False
6667

67-
if dims is None:
68-
self.shape = (len(self.diag), len(self.diag))
69-
self.dims = None
70-
self.reshape = False
71-
else:
72-
diagdims = [1] * len(dims)
73-
diagdims[axis] = dims[axis]
74-
self.diag = self.diag.reshape(diagdims)
75-
self.shape = (np.prod(dims), np.prod(dims))
76-
self.dims = dims
77-
self.reshape = True
68+
ncp = get_array_module(diag)
69+
self.diag = diag.ravel()
70+
self.complex = True if ncp.iscomplexobj(self.diag) else False
71+
self.dims = self.dimsd = (
72+
(len(self.diag),) if dims is None else _value_or_list_like_to_tuple(dims)
73+
)
74+
75+
diagdims = np.ones_like(self.dims)
76+
diagdims[axis] = self.dims[axis]
77+
self.diag = self.diag.reshape(diagdims)
78+
79+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
7880
self.dtype = np.dtype(dtype)
7981
self.explicit = False
8082

8183
def _matvec(self, x):
8284
if type(self.diag) != type(x):
8385
self.diag = to_cupy_conditional(x, self.diag)
84-
if not self.reshape:
85-
y = self.diag * x.ravel()
86-
else:
87-
x = x.reshape(self.dims)
88-
y = self.diag * x
86+
x = x.reshape(self.dims)
87+
y = self.diag * x
8988
return y.ravel()
9089

9190
def _rmatvec(self, x):
@@ -95,11 +94,8 @@ def _rmatvec(self, x):
9594
diagadj = self.diag.conj()
9695
else:
9796
diagadj = self.diag
98-
if not self.reshape:
99-
y = diagadj * x.ravel()
100-
else:
101-
x = x.reshape(self.dims)
102-
y = diagadj * x
97+
x = x.reshape(self.dims)
98+
y = diagadj * x
10399
return y.ravel()
104100

105101
def matrix(self):

pylops/basicoperators/DirectionalDerivative.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ def FirstDirectionalDerivative(
6767
else:
6868
Dop = Diagonal(v.ravel(), dtype=dtype)
6969
Sop = Sum(dims=[len(dims)] + list(dims), axis=0, dtype=dtype)
70-
ddop = Sop * Dop * Gop
71-
return LinearOperator(ddop)
70+
ddop = LinearOperator(Sop * Dop * Gop)
71+
ddop.dims = ddop.dimsd = dims
72+
ddop.sampling = sampling
73+
return ddop
7274

7375

7476
def SecondDirectionalDerivative(dims, v, sampling=1, edge=False, dtype="float64"):
@@ -118,5 +120,7 @@ def SecondDirectionalDerivative(dims, v, sampling=1, edge=False, dtype="float64"
118120
in the literature.
119121
"""
120122
Dop = FirstDirectionalDerivative(dims, v, sampling=sampling, edge=edge, dtype=dtype)
121-
ddop = -Dop.H * Dop
122-
return LinearOperator(ddop)
123+
ddop = LinearOperator(-Dop.H * Dop)
124+
ddop.dims = ddop.dimsd = dims
125+
ddop.sampling = sampling
126+
return ddop

pylops/basicoperators/FirstDerivative.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from numpy.core.multiarray import normalize_axis_index
55

66
from pylops import LinearOperator
7-
from pylops.utils._internal import _value_or_list_like_to_array
7+
from pylops.utils._internal import _value_or_list_like_to_tuple
88
from pylops.utils.backend import get_array_module
99

1010

@@ -73,13 +73,13 @@ def __init__(
7373
dtype="float64",
7474
kind="centered",
7575
):
76-
self.dims = _value_or_list_like_to_array(dims)
76+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
7777
self.axis = normalize_axis_index(axis, len(self.dims))
7878
self.sampling = sampling
7979
self.edge = edge
8080
self.kind = kind
81-
N = np.prod(self.dims)
82-
self.shape = (N, N)
81+
82+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
8383
self.dtype = np.dtype(dtype)
8484
self.explicit = False
8585

pylops/basicoperators/Flip.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from pylops import LinearOperator
6-
from pylops.utils._internal import _value_or_list_like_to_array
6+
from pylops.utils._internal import _value_or_list_like_to_tuple
77

88

99
class Flip(LinearOperator):
@@ -46,10 +46,10 @@ class Flip(LinearOperator):
4646
"""
4747

4848
def __init__(self, dims, axis=-1, dtype="float64"):
49-
self.dims = _value_or_list_like_to_array(dims)
49+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
5050
self.axis = axis
51-
N = np.prod(self.dims)
52-
self.shape = (N, N)
51+
52+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
5353
self.dtype = np.dtype(dtype)
5454
self.explicit = False
5555

pylops/basicoperators/Gradient.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pylops.basicoperators import FirstDerivative, VStack
4+
from pylops.utils._internal import _value_or_list_like_to_tuple
45

56

67
def Gradient(dims, sampling=1, edge=False, dtype="float64", kind="centered"):
@@ -58,9 +59,9 @@ def Gradient(dims, sampling=1, edge=False, dtype="float64", kind="centered"):
5859
axes are instead summed together.
5960
6061
"""
62+
dims = _value_or_list_like_to_tuple(dims)
6163
ndims = len(dims)
62-
if isinstance(sampling, (int, float)):
63-
sampling = [sampling] * ndims
64+
sampling = _value_or_list_like_to_tuple(sampling, repeat=ndims)
6465

6566
gop = VStack(
6667
[
@@ -75,4 +76,9 @@ def Gradient(dims, sampling=1, edge=False, dtype="float64", kind="centered"):
7576
for iax in range(ndims)
7677
]
7778
)
79+
gop.dims = dims
80+
gop.dimsd = (ndims, *gop.dims)
81+
gop.sampling = sampling
82+
gop.edge = edge
83+
gop.kind = kind
7884
return gop

pylops/basicoperators/Imag.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pylops import LinearOperator
4+
from pylops.utils._internal import _value_or_list_like_to_tuple
45
from pylops.utils.backend import get_array_module
56

67

@@ -43,7 +44,9 @@ class Imag(LinearOperator):
4344
"""
4445

4546
def __init__(self, dims, dtype="complex128"):
46-
self.shape = (np.prod(np.array(dims)), np.prod(np.array(dims)))
47+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
48+
49+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
4750
self.dtype = np.dtype(dtype)
4851
self.rdtype = np.real(np.ones(1, self.dtype)).dtype
4952
self.explicit = False

pylops/basicoperators/Pad.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pylops import LinearOperator
4+
from pylops.utils._internal import _value_or_list_like_to_tuple
45

56

67
class Pad(LinearOperator):
@@ -60,14 +61,19 @@ class Pad(LinearOperator):
6061
def __init__(self, dims, pad, dtype="float64"):
6162
if np.any(np.array(pad) < 0):
6263
raise ValueError("Padding must be positive or zero")
63-
self.dims = dims
64+
self.reshape = False if isinstance(dims, int) else True
65+
self.dims = _value_or_list_like_to_tuple(dims)
6466
self.pad = pad
65-
self.reshape = False if isinstance(self.dims, int) else True
6667
if self.reshape:
67-
self.dimsd = [dim + p[0] + p[1] for dim, p in zip(dims, pad)]
68+
dimsd = [
69+
dim + before + after
70+
for dim, (before, after) in zip(self.dims, self.pad)
71+
]
6872
else:
69-
self.dimsd = dims + pad[0] + pad[1]
70-
self.shape = (np.prod(np.array(self.dimsd)), np.prod(np.array(self.dims)))
73+
dimsd = [self.dims[0] + pad[0] + pad[1]]
74+
self.dimsd = tuple(dimsd)
75+
76+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
7177
self.dtype = np.dtype(dtype)
7278
self.explicit = False
7379

@@ -82,8 +88,8 @@ def _matvec(self, x):
8288
def _rmatvec(self, x):
8389
if self.reshape:
8490
y = x.reshape(self.dimsd)
85-
for ax, pad in enumerate(self.pad):
86-
y = np.take(y, np.arange(pad[0], pad[0] + self.dims[ax]), axis=ax)
91+
for ax, (before, _) in enumerate(self.pad):
92+
y = np.take(y, np.arange(before, before + self.dims[ax]), axis=ax)
8793
else:
88-
y = x[self.pad[0] : self.pad[0] + self.dims]
94+
y = x[self.pad[0] : self.pad[0] + self.dims[0]]
8995
return y.ravel()

pylops/basicoperators/Real.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from pylops import LinearOperator
4+
from pylops.utils._internal import _value_or_list_like_to_tuple
45
from pylops.utils.backend import get_array_module
56

67

@@ -43,7 +44,9 @@ class Real(LinearOperator):
4344
"""
4445

4546
def __init__(self, dims, dtype="complex128"):
46-
self.shape = (np.prod(np.array(dims)), np.prod(np.array(dims)))
47+
self.dims = self.dimsd = _value_or_list_like_to_tuple(dims)
48+
49+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
4750
self.dtype = np.dtype(dtype)
4851
self.rdtype = np.real(np.ones(1, self.dtype)).dtype
4952
self.explicit = False

0 commit comments

Comments
 (0)