55from scipy .sparse .linalg import inv
66
77from pylops import LinearOperator
8+ from pylops .utils ._internal import _value_or_list_like_to_array
89from pylops .utils .backend import get_array_module
910
1011logging .basicConfig (format = "%(levelname)s: %(message)s" , level = logging .WARNING )
@@ -20,7 +21,7 @@ class MatrixMult(LinearOperator):
2021 ----------
2122 A : :obj:`numpy.ndarray` or :obj:`scipy.sparse` matrix
2223 Matrix.
23- dims : :obj:`tuple`, optional
24+ otherdims : :obj:`tuple`, optional
2425 Number of samples for each other dimension of model
2526 (model/data will be reshaped and ``A`` applied multiple times
2627 to each column of the model/data).
@@ -29,6 +30,10 @@ class MatrixMult(LinearOperator):
2930
3031 Attributes
3132 ----------
33+ dimsd : :obj:`tuple`
34+ Shape of the array after the forward, but before linearization.
35+
36+ For example, ``y_reshaped = (Op * x.ravel()).reshape(Op.dimsd)``.
3237 shape : :obj:`tuple`
3338 Operator shape
3439 explicit : :obj:`bool`
@@ -39,30 +44,29 @@ class MatrixMult(LinearOperator):
3944
4045 """
4146
42- def __init__ (self , A , dims = None , dtype = "float64" ):
47+ def __init__ (self , A , otherdims = None , dtype = "float64" ):
4348 ncp = get_array_module (A )
4449 self .A = A
4550 if isinstance (A , ncp .ndarray ):
4651 self .complex = np .iscomplexobj (A )
4752 else :
4853 self .complex = np .iscomplexobj (A .data )
49- if dims is None :
54+ if otherdims is None :
55+ self .dims , self .dimsd = A .shape [1 ], A .shape [0 ]
5056 self .reshape = False
5157 self .shape = A .shape
5258 self .explicit = True
5359 else :
54- if isinstance (dims , int ):
55- dims = (dims ,)
60+ otherdims = _value_or_list_like_to_array (otherdims )
61+ self .otherdims = np .array (otherdims , dtype = int )
62+ self .dims , self .dimsd = np .insert (
63+ self .otherdims , 0 , self .A .shape [1 ]
64+ ), np .insert (self .otherdims , 0 , self .A .shape [0 ])
65+ self .dimsflatten , self .dimsdflatten = np .insert (
66+ [np .prod (self .otherdims )], 0 , self .A .shape [1 ]
67+ ), np .insert ([np .prod (self .otherdims )], 0 , self .A .shape [0 ])
5668 self .reshape = True
57- self .dims = np .array (dims , dtype = int )
58- self .reshapedims = [
59- np .insert ([np .prod (self .dims )], 0 , self .A .shape [1 ]),
60- np .insert ([np .prod (self .dims )], 0 , self .A .shape [0 ]),
61- ]
62- self .shape = (
63- A .shape [0 ] * np .prod (self .dims ),
64- A .shape [1 ] * np .prod (self .dims ),
65- )
69+ self .shape = (np .prod (self .dimsd ), np .prod (self .dims ))
6670 self .explicit = False
6771 self .dtype = np .dtype (dtype )
6872 # Check dtype for correctness (upcast to complex when A is complex)
@@ -75,7 +79,7 @@ def __init__(self, A, dims=None, dtype="float64"):
7579 def _matvec (self , x ):
7680 ncp = get_array_module (x )
7781 if self .reshape :
78- x = ncp .reshape (x , self .reshapedims [ 0 ] )
82+ x = ncp .reshape (x , self .dimsflatten )
7983 y = self .A .dot (x )
8084 if self .reshape :
8185 return y .ravel ()
@@ -85,7 +89,7 @@ def _matvec(self, x):
8589 def _rmatvec (self , x ):
8690 ncp = get_array_module (x )
8791 if self .reshape :
88- x = ncp .reshape (x , self .reshapedims [ 1 ] )
92+ x = ncp .reshape (x , self .dimsdflatten )
8993 if self .complex :
9094 y = (self .A .T .dot (x .conj ())).conj ()
9195 else :
0 commit comments