22 "TorchOperator" ,
33]
44
5- from typing import Optional
5+ from typing import Optional , Callable
66
77import numpy as np
88
@@ -72,21 +72,26 @@ def __init__(
7272 self .transpb = np .roll (np .arange (2 if flatten else len (self .dims ) + 1 ), 1 )
7373 self .batch = batch
7474 self .Op = Op
75- self .matvec = self ._matvec
76- self .rmatvec = self ._rmatvec
75+ self ._register_torchop ()
7776 self .Top = _TorchOperator .apply
7877
79- def _matvec (self , x : NDArray ) -> NDArray :
78+ def _register_torchop (self ):
79+ # choose _matvec and _rmatvec
80+ self ._hmatvec : Callable
81+ self ._hrmatvec : Callable
82+
8083 if not self .batch :
81- return self .Op @ x
84+ self ._hmatvec = lambda x : self .Op @ x
85+ self ._hrmatvec = lambda x : self .Op .H @ x
8286 else :
83- return (self .Op @ x .transpose (self .transpf )).transpose (self .transpb )
87+ self ._hmatvec = lambda x : (self .Op @ x .transpose (self .transpf )).transpose (self .transpb )
88+ self ._hrmatvec = lambda x : (self .Op .H @ x .transpose (self .transpf )).transpose (self .transpb )
89+
90+ def _matvec (self , x : NDArray ) -> NDArray :
91+ return self ._hmatvec (x )
8492
8593 def _rmatvec (self , x : NDArray ) -> NDArray :
86- if not self .batch :
87- return self .Op .H @ x
88- else :
89- return (self .Op .H @ x .transpose (self .transpf )).transpose (self .transpb )
94+ return self ._hrmatvec (x )
9095
9196 def apply (self , x : TensorTypeLike ) -> TensorTypeLike :
9297 """Apply forward pass to input vector
@@ -102,4 +107,4 @@ def apply(self, x: TensorTypeLike) -> TensorTypeLike:
102107 Output array resulting from the application of the operator to ``x``.
103108
104109 """
105- return self .Top (x , self .matvec , self .rmatvec , self .device , self .devicetorch )
110+ return self .Top (x , self ._hmatvec , self ._hrmatvec , self .device , self .devicetorch )
0 commit comments