Skip to content

Commit 46e55a0

Browse files
committed
minor: small improvements to allow adding ABC to LinearOperator
1 parent b2c0790 commit 46e55a0

File tree

10 files changed

+75
-67
lines changed

10 files changed

+75
-67
lines changed

examples/plot_sliding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@
145145
# This is because we have not inverted our operator but simply applied
146146
# the adjoint to estimate the representation of the input data in the Radon
147147
# domain. We can do better if we use the inverse instead.
148-
radoninv = pylops.aslinearoperator(Slid).div(data.ravel(), niter=10)
148+
radoninv = Slid.div(data.ravel(), niter=10)
149149
reconstructed_datainv = Slid * radoninv.ravel()
150150

151151
radoninv = radoninv.reshape(dims)
@@ -288,7 +288,7 @@
288288

289289
reconstructed_data = Slid * radon
290290

291-
radoninv = pylops.aslinearoperator(Slid).div(data.ravel(), niter=10)
291+
radoninv = Slid.div(data.ravel(), niter=10)
292292
radoninv = radoninv.reshape(Slid.dims)
293293
reconstructed_datainv = Slid * radoninv
294294

pylops/basicoperators/firstderivative.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["FirstDerivative"]
22

3-
from typing import Union
3+
from typing import Callable, Union
44

55
import numpy as np
66
from numpy.core.multiarray import normalize_axis_index
@@ -100,40 +100,41 @@ def __init__(
100100
self.kind = kind
101101
self.edge = edge
102102
self.order = order
103+
self._register_multiplications(self.kind, self.order)
103104

104-
def _matvec(self, x: NDArray) -> NDArray:
105-
if self.kind == "forward":
106-
return self._matvec_forward(x)
107-
elif self.kind == "backward":
108-
return self._matvec_backward(x)
109-
elif self.kind == "centered":
110-
if self.order == 3:
111-
return self._matvec_centered3(x)
112-
elif self.order == 5:
113-
return self._matvec_centered5(x)
105+
def _register_multiplications(
106+
self,
107+
kind: str,
108+
order: int,
109+
) -> None:
110+
# choose _matvec and _rmatvec kind
111+
self._hmatvec: Callable
112+
self._hrmatvec: Callable
113+
if kind == "forward":
114+
self._hmatvec = self._matvec_forward
115+
self._hrmatvec = self._rmatvec_forward
116+
elif kind == "centered":
117+
if order == 3:
118+
self._hmatvec = self._matvec_centered3
119+
self._hrmatvec = self._rmatvec_centered3
120+
elif order == 5:
121+
self._hmatvec = self._matvec_centered5
122+
self._hrmatvec = self._rmatvec_centered5
114123
else:
115124
raise NotImplementedError("'order' must be '3, or '5'")
125+
elif kind == "backward":
126+
self._hmatvec = self._matvec_backward
127+
self._hrmatvec = self._rmatvec_backward
116128
else:
117129
raise NotImplementedError(
118-
"'kind' must be 'forward', 'centered' or 'backward'"
130+
"'kind' must be 'forward', 'centered', or 'backward'"
119131
)
120132

133+
def _matvec(self, x: NDArray) -> NDArray:
134+
return self._hmatvec(x)
135+
121136
def _rmatvec(self, x: NDArray) -> NDArray:
122-
if self.kind == "forward":
123-
return self._rmatvec_forward(x)
124-
elif self.kind == "backward":
125-
return self._rmatvec_backward(x)
126-
elif self.kind == "centered":
127-
if self.order == 3:
128-
return self._rmatvec_centered3(x)
129-
elif self.order == 5:
130-
return self._rmatvec_centered5(x)
131-
else:
132-
raise NotImplementedError("'order' must be '3, or '5'")
133-
else:
134-
raise NotImplementedError(
135-
"'kind' must be 'forward', 'centered' or 'backward'"
136-
)
137+
return self._hrmatvec(x)
137138

138139
@reshaped(swapaxis=True)
139140
def _matvec_forward(self, x: NDArray) -> NDArray:

pylops/basicoperators/secondderivative.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ["SecondDerivative"]
22

3-
from typing import Union
3+
from typing import Callable, Union
44

55
import numpy as np
66
from numpy.core.multiarray import normalize_axis_index
@@ -90,30 +90,34 @@ def __init__(
9090
self.sampling = sampling
9191
self.kind = kind
9292
self.edge = edge
93+
self._register_multiplications(self.kind)
9394

94-
def _matvec(self, x: NDArray) -> NDArray:
95-
if self.kind == "forward":
96-
return self._matvec_forward(x)
97-
elif self.kind == "backward":
98-
return self._matvec_backward(x)
99-
elif self.kind == "centered":
100-
return self._matvec_centered(x)
95+
def _register_multiplications(
96+
self,
97+
kind: str,
98+
) -> None:
99+
# choose _matvec and _rmatvec kind
100+
self._hmatvec: Callable
101+
self._hrmatvec: Callable
102+
if kind == "forward":
103+
self._hmatvec = self._matvec_forward
104+
self._hrmatvec = self._rmatvec_forward
105+
elif kind == "centered":
106+
self._hmatvec = self._matvec_centered
107+
self._hrmatvec = self._rmatvec_centered
108+
elif kind == "backward":
109+
self._hmatvec = self._matvec_backward
110+
self._hrmatvec = self._rmatvec_backward
101111
else:
102112
raise NotImplementedError(
103113
"'kind' must be 'forward', 'centered' or 'backward'"
104114
)
105115

116+
def _matvec(self, x: NDArray) -> NDArray:
117+
return self._hmatvec(x)
118+
106119
def _rmatvec(self, x: NDArray) -> NDArray:
107-
if self.kind == "forward":
108-
return self._rmatvec_forward(x)
109-
elif self.kind == "backward":
110-
return self._rmatvec_backward(x)
111-
elif self.kind == "centered":
112-
return self._rmatvec_centered(x)
113-
else:
114-
raise NotImplementedError(
115-
"'kind' must be 'forward', 'centered' or 'backward'"
116-
)
120+
return self._hrmatvec(x)
117121

118122
@reshaped(swapaxis=True)
119123
def _matvec_forward(self, x: NDArray) -> NDArray:

pylops/linearoperator.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
"aslinearoperator",
66
]
77

8-
from abc import ABCMeta, abstractmethod
9-
108
import logging
9+
from abc import ABC, abstractmethod
1110

1211
import numpy as np
1312
import scipy as sp
@@ -41,7 +40,21 @@
4140
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
4241

4342

44-
class LinearOperator(metaclass=ABCMeta):
43+
class _LinearOperator(ABC):
44+
"""Meta-class for Linear operator"""
45+
46+
@abstractmethod
47+
def _matvec(self, x: NDArray) -> NDArray:
48+
"""Matrix-vector multiplication handler."""
49+
pass
50+
51+
@abstractmethod
52+
def _rmatvec(self, x: NDArray) -> NDArray:
53+
"""Matrix-vector adjoint multiplication handler."""
54+
pass
55+
56+
57+
class LinearOperator(_LinearOperator):
4558
"""Common interface for performing matrix-vector products.
4659
4760
This class acts as an abstract interface between matrix-like
@@ -369,13 +382,11 @@ def _copy_attributes(
369382
if hasattr(self, attr):
370383
setattr(dest, attr, getattr(self, attr))
371384

372-
@abstractmethod
373385
def _matvec(self, x: NDArray) -> NDArray:
374386
"""Matrix-vector multiplication handler."""
375387
if self.Op is not None:
376388
return self.Op._matvec(x)
377389

378-
@abstractmethod
379390
def _rmatvec(self, x: NDArray) -> NDArray:
380391
"""Matrix-vector adjoint multiplication handler."""
381392
if self.Op is not None:
@@ -1482,3 +1493,5 @@ def aslinearoperator(Op: Union[spLinearOperator, LinearOperator]) -> LinearOpera
14821493
"""
14831494
if isinstance(Op, LinearOperator):
14841495
return Op
1496+
else:
1497+
return LinearOperator(Op)

pylops/signalprocessing/patch2d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,12 @@ def Patch2D(
264264
for win_in, win_end in zip(dwin0_ins, dwin0_ends)
265265
]
266266
)
267-
Pop = aslinearoperator(combining0 * combining1 * OOp)
267+
Pop = combining0 * combining1 * OOp
268268
Pop.dims, Pop.dimsd = (
269269
nwins0,
270270
nwins1,
271271
int(dims[0] // nwins0),
272272
int(dims[1] // nwins1),
273273
), dimsd
274-
OOp.dims = Pop.dims
275-
combining0.dimsd = Pop.dimsd
276274
Pop.name = name
277275
return Pop

pylops/signalprocessing/patch3d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def Patch3D(
443443
]
444444
)
445445

446-
Pop = aslinearoperator(combining0 * combining1 * combining2 * OOp)
446+
Pop = combining0 * combining1 * combining2 * OOp
447447
Pop.dims, Pop.dimsd = (
448448
nwins0,
449449
nwins1,
@@ -452,7 +452,5 @@ def Patch3D(
452452
int(dims[1] // nwins1),
453453
int(dims[2] // nwins2),
454454
), dimsd
455-
OOp.dims = Pop.dims
456-
combining0.dimsd = Pop.dimsd
457455
Pop.name = name
458456
return Pop

pylops/signalprocessing/sliding1d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,7 @@ def Sliding1D(
180180
for win_in, win_end in zip(dwin_ins, dwin_ends)
181181
]
182182
)
183-
Sop = aslinearoperator(combining * OOp)
183+
Sop = combining * OOp
184184
Sop.dims, Sop.dimsd = (nwins, int(dim[0] // nwins)), dimd
185185
Sop.name = name
186-
OOp.dims = Sop.dims
187-
combining.dimsd = Sop.dimsd
188186
return Sop

pylops/signalprocessing/sliding2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,5 @@ def Sliding2D(
216216
)
217217
Sop = aslinearoperator(combining * OOp)
218218
Sop.dims, Sop.dimsd = (nwins, int(dims[0] // nwins), dims[1]), dimsd
219-
OOp.dims = Sop.dims
220-
combining.dimsd = Sop.dimsd
221219
Sop.name = name
222220
return Sop

pylops/signalprocessing/sliding3d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def Sliding3D(
215215
for win_in, win_end in zip(dwin0_ins, dwin0_ends)
216216
]
217217
)
218-
Sop = aslinearoperator(combining0 * combining1 * OOp)
218+
Sop = combining0 * combining1 * OOp
219219
Sop.dims, Sop.dimsd = (
220220
nwins0,
221221
nwins1,
@@ -224,6 +224,4 @@ def Sliding3D(
224224
dims[2],
225225
), dimsd
226226
Sop.name = name
227-
combining0.dimsd = Sop.dimsd
228-
OOp.dims = Sop.dims
229227
return Sop

pytests/test_linearoperator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_scaled(par):
7171
def test_scipyop(par):
7272
"""Verify interaction between pylops and scipy Linear operators"""
7373

74-
class spDiag(LinearOperator):
74+
class spDiag(spLinearOperator):
7575
def __init__(self, x):
7676
self.x = x
7777
self.shape = (len(x), len(x))

0 commit comments

Comments
 (0)