|
2 | 2 | "TorchOperator", |
3 | 3 | ] |
4 | 4 |
|
5 | | -from typing import Optional, Callable |
| 5 | +from typing import Optional |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 |
|
9 | 9 | from pylops import LinearOperator |
10 | | -from pylops.utils import deps, NDArray |
| 10 | +from pylops.utils import deps |
11 | 11 |
|
12 | 12 | if deps.torch_enabled: |
13 | 13 | from pylops._torchoperator import _TorchOperator |
|
20 | 20 | from pylops.utils.typing import TensorTypeLike |
21 | 21 |
|
22 | 22 |
|
23 | | -class TorchOperator(LinearOperator): |
| 23 | +class TorchOperator: |
24 | 24 | """Wrap a PyLops operator into a Torch function. |
25 | 25 |
|
26 | 26 | This class can be used to wrap a pylops operator into a |
@@ -63,33 +63,24 @@ def __init__( |
63 | 63 | raise NotImplementedError(torch_message) |
64 | 64 | self.device = device |
65 | 65 | self.devicetorch = devicetorch |
66 | | - super().__init__( |
67 | | - dtype=np.dtype(Op.dtype), dims=Op.dims, dimsd=Op.dims, name=Op.name |
68 | | - ) |
| 66 | + self.dtype = np.dtype(Op.dtype) |
| 67 | + self.dims, self.dimsd = Op.dims, Op.dimsd |
| 68 | + self.name = Op.name |
69 | 69 | # define transpose indices to bring batch to last dimension before applying |
70 | 70 | # pylops forward and adjoint (this will call matmat and rmatmat) |
71 | 71 | self.transpf = np.roll(np.arange(2 if flatten else len(self.dims) + 1), -1) |
72 | 72 | self.transpb = np.roll(np.arange(2 if flatten else len(self.dims) + 1), 1) |
73 | | - self.Op = Op |
74 | | - self._register_torchop(batch) |
75 | | - self.Top = _TorchOperator.apply |
76 | | - |
77 | | - def _register_torchop(self, batch: bool): |
78 | | - # choose _matvec and _rmatvec |
79 | | - self.matvec: Callable |
80 | | - self.rmatvec: Callable |
81 | 73 | if not batch: |
82 | | - self.matvec = lambda x: self.Op @ x |
83 | | - self.rmatvec = lambda x: self.Op.H @ x |
| 74 | + self.matvec = lambda x: Op @ x |
| 75 | + self.rmatvec = lambda x: Op.H @ x |
84 | 76 | else: |
85 | | - self.matvec = lambda x: (self.Op @ x.transpose(self.transpf)).transpose(self.transpb) |
86 | | - self.rmatvec = lambda x: (self.Op.H @ x.transpose(self.transpf)).transpose(self.transpb) |
87 | | - |
88 | | - def _matvec(self, x: NDArray) -> NDArray: |
89 | | - return self.matvec(x) |
90 | | - |
91 | | - def _rmatvec(self, x: NDArray) -> NDArray: |
92 | | - return self.rmatvec(x) |
| 77 | + self.matvec = lambda x: (Op @ x.transpose(self.transpf)).transpose( |
| 78 | + self.transpb |
| 79 | + ) |
| 80 | + self.rmatvec = lambda x: (Op.H @ x.transpose(self.transpf)).transpose( |
| 81 | + self.transpb |
| 82 | + ) |
| 83 | + self.Top = _TorchOperator.apply |
93 | 84 |
|
94 | 85 | def apply(self, x: TensorTypeLike) -> TensorTypeLike: |
95 | 86 | """Apply forward pass to input vector |
|
0 commit comments