Skip to content

Commit 5401d9c

Browse files
committed
minor: added dims and dimsd to MatrixMult
Changed dims in input parameters to otherdims to be able to have dims and dimsd consistent with the other operators.
1 parent 67aefb2 commit 5401d9c

File tree

4 files changed

+29
-25
lines changed

4 files changed

+29
-25
lines changed

pylops/avo/poststack.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _PoststackLinearModelling(
8787
M = ncp.dot(C, D)
8888
if sparse:
8989
M = get_csc_matrix(wav)(M)
90-
Pop = _MatrixMult(M, dims=spatdims, dtype=dtype, **args_MatrixMult)
90+
Pop = _MatrixMult(M, otherdims=spatdims, dtype=dtype, **args_MatrixMult)
9191
else:
9292
# Create wavelet operator
9393
if len(wav.shape) == 1:
@@ -102,7 +102,7 @@ def _PoststackLinearModelling(
102102
else:
103103
Cop = _MatrixMult(
104104
nonstationary_convmtx(wav, nt0, hc=wav.shape[1] // 2, pad=(nt0, nt0)),
105-
dims=spatdims,
105+
otherdims=spatdims,
106106
dtype=dtype,
107107
**args_MatrixMult
108108
)
@@ -350,7 +350,7 @@ def PoststackInversion(
350350
minv = get_lstsq(data)(PP, datarn, **kwargs_solver)[0]
351351
else:
352352
# solve regularized normal equations simultaneously
353-
PPop_reg = MatrixMult(PP, dims=nspatprod)
353+
PPop_reg = MatrixMult(PP, otherdims=nspatprod)
354354
if ncp == np:
355355
minv = lsqr(PPop_reg, datar.ravel(), **kwargs_solver)[0]
356356
else:
@@ -364,7 +364,7 @@ def PoststackInversion(
364364
# create regularized normal eqs. and solve them simultaneously
365365
PP = ncp.dot(PPop.A.T, PPop.A) + epsI * ncp.eye(nt0, dtype=PPop.A.dtype)
366366
datarn = PPop.A.T * datar.reshape(nt0, nspatprod)
367-
PPop_reg = MatrixMult(PP, dims=nspatprod)
367+
PPop_reg = MatrixMult(PP, otherdims=nspatprod)
368368
minv = get_lstsq(data)(PPop_reg.A, datarn.ravel(), **kwargs_solver)[0]
369369
else:
370370
# solve unregularized normal equations simultaneously with lop

pylops/avo/prestack.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def PrestackLinearModelling(
182182

183183
# Combine operators
184184
M = ncp.dot(C, ncp.dot(G, D))
185-
Preop = MatrixMult(M, dims=spatdims, dtype=dtype)
185+
Preop = MatrixMult(M, otherdims=spatdims, dtype=dtype)
186186

187187
else:
188188
# Create wavelet operator
@@ -504,7 +504,7 @@ def PrestackInversion(
504504
]
505505
if explicit:
506506
PPop = MatrixMult(
507-
np.vstack([Op.A for Op in PPop]), dims=nspat, dtype=PPop[0].A.dtype
507+
np.vstack([Op.A for Op in PPop]), otherdims=nspat, dtype=PPop[0].A.dtype
508508
)
509509
else:
510510
PPop = VStack(PPop)
@@ -560,7 +560,7 @@ def PrestackInversion(
560560
minv = get_lstsq(data)(PP, datarn, **kwargs_solver)[0]
561561
else:
562562
# solve regularized normal equations simultaneously
563-
PPop_reg = MatrixMult(PP, dims=nspatprod)
563+
PPop_reg = MatrixMult(PP, otherdims=nspatprod)
564564
if ncp == np:
565565
minv = lsqr(PPop_reg, datarn.ravel(), **kwargs_solver)[0]
566566
else:
@@ -574,7 +574,7 @@ def PrestackInversion(
574574
# # create regularized normal eqs. and solve them simultaneously
575575
# PP = np.dot(PPop.A.T, PPop.A) + epsI * np.eye(nt0*nm)
576576
# datarn = PPop.A.T * datar.reshape(nt0*ntheta, nspatprod)
577-
# PPop_reg = MatrixMult(PP, dims=ntheta*nspatprod)
577+
# PPop_reg = MatrixMult(PP, otherdims=ntheta*nspatprod)
578578
# minv = lstsq(PPop_reg, datarn.ravel(), **kwargs_solver)[0]
579579
else:
580580
# solve unregularized normal equations simultaneously with lop

pylops/basicoperators/MatrixMult.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from scipy.sparse.linalg import inv
66

77
from pylops import LinearOperator
8+
from pylops.utils._internal import _value_or_list_like_to_array
89
from pylops.utils.backend import get_array_module
910

1011
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
@@ -20,7 +21,7 @@ class MatrixMult(LinearOperator):
2021
----------
2122
A : :obj:`numpy.ndarray` or :obj:`scipy.sparse` matrix
2223
Matrix.
23-
dims : :obj:`tuple`, optional
24+
otherdims : :obj:`tuple`, optional
2425
Number of samples for each other dimension of model
2526
(model/data will be reshaped and ``A`` applied multiple times
2627
to each column of the model/data).
@@ -29,6 +30,10 @@ class MatrixMult(LinearOperator):
2930
3031
Attributes
3132
----------
33+
dimsd : :obj:`tuple`
34+
Shape of the array after the forward, but before linearization.
35+
36+
For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
3237
shape : :obj:`tuple`
3338
Operator shape
3439
explicit : :obj:`bool`
@@ -39,30 +44,29 @@ class MatrixMult(LinearOperator):
3944
4045
"""
4146

42-
def __init__(self, A, dims=None, dtype="float64"):
47+
def __init__(self, A, otherdims=None, dtype="float64"):
4348
ncp = get_array_module(A)
4449
self.A = A
4550
if isinstance(A, ncp.ndarray):
4651
self.complex = np.iscomplexobj(A)
4752
else:
4853
self.complex = np.iscomplexobj(A.data)
49-
if dims is None:
54+
if otherdims is None:
55+
self.dims, self.dimsd = A.shape[1], A.shape[0]
5056
self.reshape = False
5157
self.shape = A.shape
5258
self.explicit = True
5359
else:
54-
if isinstance(dims, int):
55-
dims = (dims,)
60+
otherdims = _value_or_list_like_to_array(otherdims)
61+
self.otherdims = np.array(otherdims, dtype=int)
62+
self.dims, self.dimsd = np.insert(
63+
self.otherdims, 0, self.A.shape[1]
64+
), np.insert(self.otherdims, 0, self.A.shape[0])
65+
self.dimsflatten, self.dimsdflatten = np.insert(
66+
[np.prod(self.otherdims)], 0, self.A.shape[1]
67+
), np.insert([np.prod(self.otherdims)], 0, self.A.shape[0])
5668
self.reshape = True
57-
self.dims = np.array(dims, dtype=int)
58-
self.reshapedims = [
59-
np.insert([np.prod(self.dims)], 0, self.A.shape[1]),
60-
np.insert([np.prod(self.dims)], 0, self.A.shape[0]),
61-
]
62-
self.shape = (
63-
A.shape[0] * np.prod(self.dims),
64-
A.shape[1] * np.prod(self.dims),
65-
)
69+
self.shape = (np.prod(self.dimsd), np.prod(self.dims))
6670
self.explicit = False
6771
self.dtype = np.dtype(dtype)
6872
# Check dtype for correctness (upcast to complex when A is complex)
@@ -75,7 +79,7 @@ def __init__(self, A, dims=None, dtype="float64"):
7579
def _matvec(self, x):
7680
ncp = get_array_module(x)
7781
if self.reshape:
78-
x = ncp.reshape(x, self.reshapedims[0])
82+
x = ncp.reshape(x, self.dimsflatten)
7983
y = self.A.dot(x)
8084
if self.reshape:
8185
return y.ravel()
@@ -85,7 +89,7 @@ def _matvec(self, x):
8589
def _rmatvec(self, x):
8690
ncp = get_array_module(x)
8791
if self.reshape:
88-
x = ncp.reshape(x, self.reshapedims[1])
92+
x = ncp.reshape(x, self.dimsdflatten)
8993
if self.complex:
9094
y = (self.A.T.dot(x.conj())).conj()
9195
else:

pytests/test_basicoperators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_MatrixMult_repeated(par):
127127
G = np.random.normal(0, 10, (par["ny"], par["nx"])).astype("float32") + par[
128128
"imag"
129129
] * np.random.normal(0, 10, (par["ny"], par["nx"])).astype("float32")
130-
Gop = MatrixMult(G, dims=5, dtype=par["dtype"])
130+
Gop = MatrixMult(G, otherdims=5, dtype=par["dtype"])
131131
assert dottest(
132132
Gop, par["ny"] * 5, par["nx"] * 5, complexflag=0 if par["imag"] == 1 else 3
133133
)

0 commit comments

Comments
 (0)