Skip to content

Commit 65cff84

Browse files
committed
Updated code for torchoperator
1 parent af1e822 commit 65cff84

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

pylops/torchoperator.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"TorchOperator",
33
]
44

5-
from typing import Optional
5+
from typing import Optional, Callable
66

77
import numpy as np
88

@@ -72,21 +72,26 @@ def __init__(
7272
self.transpb = np.roll(np.arange(2 if flatten else len(self.dims) + 1), 1)
7373
self.batch = batch
7474
self.Op = Op
75-
self.matvec = self._matvec
76-
self.rmatvec = self._rmatvec
75+
self._register_torchop()
7776
self.Top = _TorchOperator.apply
7877

79-
def _matvec(self, x: NDArray) -> NDArray:
78+
def _register_torchop(self):
79+
# choose _matvec and _rmatvec
80+
self._hmatvec: Callable
81+
self._hrmatvec: Callable
82+
8083
if not self.batch:
81-
return self.Op @ x
84+
self._hmatvec = lambda x: self.Op @ x
85+
self._hrmatvec = lambda x: self.Op.H @ x
8286
else:
83-
return (self.Op @ x.transpose(self.transpf)).transpose(self.transpb)
87+
self._hmatvec = lambda x: (self.Op @ x.transpose(self.transpf)).transpose(self.transpb)
88+
self._hrmatvec = lambda x: (self.Op.H @ x.transpose(self.transpf)).transpose(self.transpb)
89+
90+
def _matvec(self, x: NDArray) -> NDArray:
91+
return self._hmatvec(x)
8492

8593
def _rmatvec(self, x: NDArray) -> NDArray:
86-
if not self.batch:
87-
return self.Op.H @ x
88-
else:
89-
return (self.Op.H @ x.transpose(self.transpf)).transpose(self.transpb)
94+
return self._hrmatvec(x)
9095

9196
def apply(self, x: TensorTypeLike) -> TensorTypeLike:
9297
"""Apply forward pass to input vector
@@ -102,4 +107,4 @@ def apply(self, x: TensorTypeLike) -> TensorTypeLike:
102107
Output array resulting from the application of the operator to ``x``.
103108
104109
"""
105-
return self.Top(x, self.matvec, self.rmatvec, self.device, self.devicetorch)
110+
return self.Top(x, self._hmatvec, self._hrmatvec, self.device, self.devicetorch)

0 commit comments

Comments
 (0)