Skip to content

Commit 4428dda

Browse files
committed
feature: remove LinearOperator inheritance from TorchOperator
1 parent 363873a commit 4428dda

File tree

1 file changed

+15
-24
lines changed

1 file changed

+15
-24
lines changed

pylops/torchoperator.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
"TorchOperator",
33
]
44

5-
from typing import Optional, Callable
5+
from typing import Optional
66

77
import numpy as np
88

99
from pylops import LinearOperator
10-
from pylops.utils import deps, NDArray
10+
from pylops.utils import deps
1111

1212
if deps.torch_enabled:
1313
from pylops._torchoperator import _TorchOperator
@@ -20,7 +20,7 @@
2020
from pylops.utils.typing import TensorTypeLike
2121

2222

23-
class TorchOperator(LinearOperator):
23+
class TorchOperator:
2424
"""Wrap a PyLops operator into a Torch function.
2525
2626
This class can be used to wrap a pylops operator into a
@@ -63,33 +63,24 @@ def __init__(
6363
raise NotImplementedError(torch_message)
6464
self.device = device
6565
self.devicetorch = devicetorch
66-
super().__init__(
67-
dtype=np.dtype(Op.dtype), dims=Op.dims, dimsd=Op.dims, name=Op.name
68-
)
66+
self.dtype = np.dtype(Op.dtype)
67+
self.dims, self.dimsd = Op.dims, Op.dimsd
68+
self.name = Op.name
6969
# define transpose indices to bring batch to last dimension before applying
7070
# pylops forward and adjoint (this will call matmat and rmatmat)
7171
self.transpf = np.roll(np.arange(2 if flatten else len(self.dims) + 1), -1)
7272
self.transpb = np.roll(np.arange(2 if flatten else len(self.dims) + 1), 1)
73-
self.Op = Op
74-
self._register_torchop(batch)
75-
self.Top = _TorchOperator.apply
76-
77-
def _register_torchop(self, batch: bool):
78-
# choose _matvec and _rmatvec
79-
self.matvec: Callable
80-
self.rmatvec: Callable
8173
if not batch:
82-
self.matvec = lambda x: self.Op @ x
83-
self.rmatvec = lambda x: self.Op.H @ x
74+
self.matvec = lambda x: Op @ x
75+
self.rmatvec = lambda x: Op.H @ x
8476
else:
85-
self.matvec = lambda x: (self.Op @ x.transpose(self.transpf)).transpose(self.transpb)
86-
self.rmatvec = lambda x: (self.Op.H @ x.transpose(self.transpf)).transpose(self.transpb)
87-
88-
def _matvec(self, x: NDArray) -> NDArray:
89-
return self.matvec(x)
90-
91-
def _rmatvec(self, x: NDArray) -> NDArray:
92-
return self.rmatvec(x)
77+
self.matvec = lambda x: (Op @ x.transpose(self.transpf)).transpose(
78+
self.transpb
79+
)
80+
self.rmatvec = lambda x: (Op.H @ x.transpose(self.transpf)).transpose(
81+
self.transpb
82+
)
83+
self.Top = _TorchOperator.apply
9384

9485
def apply(self, x: TensorTypeLike) -> TensorTypeLike:
9586
"""Apply forward pass to input vector

0 commit comments

Comments
 (0)