Skip to content

Commit d5c6252

Browse files
authored
Merge pull request #489 from rohanbabbar04/abc
Abstraction to `_matvec` and `_rmatvec`
2 parents 5af878d + fee227f commit d5c6252

File tree

14 files changed

+77
-54
lines changed

14 files changed

+77
-54
lines changed

examples/plot_sliding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@
145145
# This is because we have not inverted our operator but simply applied
146146
# the adjoint to estimate the representation of the input data in the Radon
147147
# domain. We can do better if we use the inverse instead.
148-
radoninv = pylops.LinearOperator(Slid, explicit=False).div(data.ravel(), niter=10)
148+
radoninv = Slid.div(data.ravel(), niter=10)
149149
reconstructed_datainv = Slid * radoninv.ravel()
150150

151151
radoninv = radoninv.reshape(dims)
@@ -288,7 +288,7 @@
288288

289289
reconstructed_data = Slid * radon
290290

291-
radoninv = pylops.LinearOperator(Slid, explicit=False).div(data.ravel(), niter=10)
291+
radoninv = Slid.div(data.ravel(), niter=10)
292292
radoninv = radoninv.reshape(Slid.dims)
293293
reconstructed_datainv = Slid * radoninv
294294

pylops/basicoperators/directionalderivative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def FirstDirectionalDerivative(
7777
else:
7878
Dop = Diagonal(v.ravel(), dtype=dtype)
7979
Sop = Sum(dims=[len(dims)] + list(dims), axis=0, dtype=dtype)
80-
ddop = LinearOperator(Sop * Dop * Gop)
80+
ddop = Sop * Dop * Gop
8181
ddop.dims = ddop.dimsd = dims
8282
ddop.sampling = sampling
8383
return ddop
@@ -136,7 +136,7 @@ def SecondDirectionalDerivative(
136136
in the literature.
137137
"""
138138
Dop = FirstDirectionalDerivative(dims, v, sampling=sampling, edge=edge, dtype=dtype)
139-
ddop = LinearOperator(-Dop.H * Dop)
139+
ddop = -Dop.H * Dop
140140
ddop.dims = ddop.dimsd = dims
141141
ddop.sampling = sampling
142142
return ddop

pylops/basicoperators/firstderivative.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,28 +108,34 @@ def _register_multiplications(
108108
order: int,
109109
) -> None:
110110
# choose _matvec and _rmatvec kind
111-
self._matvec: Callable
112-
self._rmatvec: Callable
111+
self._hmatvec: Callable
112+
self._hrmatvec: Callable
113113
if kind == "forward":
114-
self._matvec = self._matvec_forward
115-
self._rmatvec = self._rmatvec_forward
114+
self._hmatvec = self._matvec_forward
115+
self._hrmatvec = self._rmatvec_forward
116116
elif kind == "centered":
117117
if order == 3:
118-
self._matvec = self._matvec_centered3
119-
self._rmatvec = self._rmatvec_centered3
118+
self._hmatvec = self._matvec_centered3
119+
self._hrmatvec = self._rmatvec_centered3
120120
elif order == 5:
121-
self._matvec = self._matvec_centered5
122-
self._rmatvec = self._rmatvec_centered5
121+
self._hmatvec = self._matvec_centered5
122+
self._hrmatvec = self._rmatvec_centered5
123123
else:
124124
raise NotImplementedError("'order' must be '3, or '5'")
125125
elif kind == "backward":
126-
self._matvec = self._matvec_backward
127-
self._rmatvec = self._rmatvec_backward
126+
self._hmatvec = self._matvec_backward
127+
self._hrmatvec = self._rmatvec_backward
128128
else:
129129
raise NotImplementedError(
130130
"'kind' must be 'forward', 'centered', or 'backward'"
131131
)
132132

133+
def _matvec(self, x: NDArray) -> NDArray:
134+
return self._hmatvec(x)
135+
136+
def _rmatvec(self, x: NDArray) -> NDArray:
137+
return self._hrmatvec(x)
138+
133139
@reshaped(swapaxis=True)
134140
def _matvec_forward(self, x: NDArray) -> NDArray:
135141
ncp = get_array_module(x)

pylops/basicoperators/secondderivative.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,28 @@ def _register_multiplications(
9797
kind: str,
9898
) -> None:
9999
# choose _matvec and _rmatvec kind
100-
self._matvec: Callable
101-
self._rmatvec: Callable
100+
self._hmatvec: Callable
101+
self._hrmatvec: Callable
102102
if kind == "forward":
103-
self._matvec = self._matvec_forward
104-
self._rmatvec = self._rmatvec_forward
103+
self._hmatvec = self._matvec_forward
104+
self._hrmatvec = self._rmatvec_forward
105105
elif kind == "centered":
106-
self._matvec = self._matvec_centered
107-
self._rmatvec = self._rmatvec_centered
106+
self._hmatvec = self._matvec_centered
107+
self._hrmatvec = self._rmatvec_centered
108108
elif kind == "backward":
109-
self._matvec = self._matvec_backward
110-
self._rmatvec = self._rmatvec_backward
109+
self._hmatvec = self._matvec_backward
110+
self._hrmatvec = self._rmatvec_backward
111111
else:
112112
raise NotImplementedError(
113113
"'kind' must be 'forward', 'centered' or 'backward'"
114114
)
115115

116+
def _matvec(self, x: NDArray) -> NDArray:
117+
return self._hmatvec(x)
118+
119+
def _rmatvec(self, x: NDArray) -> NDArray:
120+
return self._hrmatvec(x)
121+
116122
@reshaped(swapaxis=True)
117123
def _matvec_forward(self, x: NDArray) -> NDArray:
118124
ncp = get_array_module(x)

pylops/linearoperator.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
"aslinearoperator",
66
]
77

8-
98
import logging
9+
from abc import ABC, abstractmethod
1010

1111
import numpy as np
1212
import scipy as sp
@@ -40,7 +40,21 @@
4040
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
4141

4242

43-
class LinearOperator:
43+
class _LinearOperator(ABC):
44+
"""Meta-class for Linear operator"""
45+
46+
@abstractmethod
47+
def _matvec(self, x: NDArray) -> NDArray:
48+
"""Matrix-vector multiplication handler."""
49+
pass
50+
51+
@abstractmethod
52+
def _rmatvec(self, x: NDArray) -> NDArray:
53+
"""Matrix-vector adjoint multiplication handler."""
54+
pass
55+
56+
57+
class LinearOperator(_LinearOperator):
4458
"""Common interface for performing matrix-vector products.
4559
4660
This class acts as an abstract interface between matrix-like
@@ -567,14 +581,14 @@ def dot(self, x: NDArray) -> NDArray:
567581
# cast x to pylops linear operator if not already (this is done
568582
# to allow mixing pylops and scipy operators)
569583
Opx = aslinearoperator(x)
570-
Op = LinearOperator(Op=_ProductLinearOperator(self, Opx))
584+
Op = _ProductLinearOperator(self, Opx)
571585
self._copy_attributes(Op, exclude=["dims", "explicit", "name"])
572586
Op.clinear = Op.clinear and Opx.clinear
573587
Op.explicit = False
574588
Op.dims = Opx.dims
575589
return Op
576590
elif np.isscalar(x):
577-
Op = LinearOperator(Op=_ScaledLinearOperator(self, x))
591+
Op = _ScaledLinearOperator(self, x)
578592
self._copy_attributes(
579593
Op,
580594
exclude=["explicit", "name"],

pylops/optimization/cls_sparsity.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,6 @@ def setup(
762762
Display setup log
763763
764764
"""
765-
self.Op = LinearOperator(self.Op)
766765
self.y = y
767766
self.niter_outer = niter_outer
768767
self.niter_inner = niter_inner
@@ -1240,10 +1239,8 @@ def setup(
12401239
if alpha is not None:
12411240
self.alpha = alpha
12421241
elif not hasattr(self, "alpha"):
1243-
if not isinstance(self.Op, LinearOperator):
1244-
self.Op = LinearOperator(self.Op, explicit=False)
12451242
# compute largest eigenvalues of Op^H * Op
1246-
Op1 = LinearOperator(self.Op.H * self.Op, explicit=False)
1243+
Op1 = self.Op.H * self.Op
12471244
if get_module_name(self.ncp) == "numpy":
12481245
maxeig: float = np.abs(
12491246
Op1.eigs(

pylops/signalprocessing/patch2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pylops import LinearOperator, aslinearoperator
11+
from pylops import LinearOperator
1212
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
1313
from pylops.signalprocessing.sliding2d import _slidingsteps
1414
from pylops.utils.tapers import taper2d
@@ -264,7 +264,7 @@ def Patch2D(
264264
for win_in, win_end in zip(dwin0_ins, dwin0_ends)
265265
]
266266
)
267-
Pop = aslinearoperator(combining0 * combining1 * OOp)
267+
Pop = LinearOperator(combining0 * combining1 * OOp)
268268
Pop.dims, Pop.dimsd = (
269269
nwins0,
270270
nwins1,

pylops/signalprocessing/patch3d.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pylops import LinearOperator, aslinearoperator
11+
from pylops import LinearOperator
1212
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
1313
from pylops.signalprocessing.sliding2d import _slidingsteps
1414
from pylops.utils.tapers import tapernd
@@ -443,7 +443,7 @@ def Patch3D(
443443
]
444444
)
445445

446-
Pop = aslinearoperator(combining0 * combining1 * combining2 * OOp)
446+
Pop = LinearOperator(combining0 * combining1 * combining2 * OOp)
447447
Pop.dims, Pop.dimsd = (
448448
nwins0,
449449
nwins1,
@@ -452,6 +452,5 @@ def Patch3D(
452452
int(dims[1] // nwins1),
453453
int(dims[2] // nwins2),
454454
), dimsd
455-
456455
Pop.name = name
457456
return Pop

pylops/signalprocessing/sliding1d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
from typing import Tuple, Union
88

9-
from pylops import LinearOperator, aslinearoperator
9+
from pylops import LinearOperator
1010
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
1111
from pylops.signalprocessing.sliding2d import _slidingsteps
1212
from pylops.utils._internal import _value_or_sized_to_tuple
@@ -180,7 +180,7 @@ def Sliding1D(
180180
for win_in, win_end in zip(dwin_ins, dwin_ends)
181181
]
182182
)
183-
Sop = aslinearoperator(combining * OOp)
183+
Sop = LinearOperator(combining * OOp)
184184
Sop.dims, Sop.dimsd = (nwins, int(dim[0] // nwins)), dimd
185185
Sop.name = name
186186
return Sop

pylops/signalprocessing/sliding2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pylops import LinearOperator, aslinearoperator
11+
from pylops import LinearOperator
1212
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
1313
from pylops.utils.tapers import taper2d
1414
from pylops.utils.typing import InputDimsLike, NDArray
@@ -214,7 +214,7 @@ def Sliding2D(
214214
for win_in, win_end in zip(dwin_ins, dwin_ends)
215215
]
216216
)
217-
Sop = aslinearoperator(combining * OOp)
217+
Sop = LinearOperator(combining * OOp)
218218
Sop.dims, Sop.dimsd = (nwins, int(dims[0] // nwins), dims[1]), dimsd
219219
Sop.name = name
220220
return Sop

0 commit comments

Comments
 (0)