33import numpy as np
44
55from pylops import LinearOperator
6+ from pylops .utils ._internal import _value_or_list_like_to_tuple
67from pylops .utils .backend import get_array_module , to_cupy_conditional
78
89
@@ -64,28 +65,26 @@ def __init__(self, diag, dims=None, axis=-1, dtype="float64"):
6465 self .diag = diag .ravel ()
6566 self .complex = True if ncp .iscomplexobj (self .diag ) else False
6667
67- if dims is None :
68- self .shape = (len (self .diag ), len (self .diag ))
69- self .dims = None
70- self .reshape = False
71- else :
72- diagdims = [1 ] * len (dims )
73- diagdims [axis ] = dims [axis ]
74- self .diag = self .diag .reshape (diagdims )
75- self .shape = (np .prod (dims ), np .prod (dims ))
76- self .dims = dims
77- self .reshape = True
68+ ncp = get_array_module (diag )
69+ self .diag = diag .ravel ()
70+ self .complex = True if ncp .iscomplexobj (self .diag ) else False
71+ self .dims = self .dimsd = (
72+ (len (self .diag ),) if dims is None else _value_or_list_like_to_tuple (dims )
73+ )
74+
75+ diagdims = np .ones_like (self .dims )
76+ diagdims [axis ] = self .dims [axis ]
77+ self .diag = self .diag .reshape (diagdims )
78+
79+ self .shape = (np .prod (self .dimsd ), np .prod (self .dims ))
7880 self .dtype = np .dtype (dtype )
7981 self .explicit = False
8082
8183 def _matvec (self , x ):
8284 if type (self .diag ) != type (x ):
8385 self .diag = to_cupy_conditional (x , self .diag )
84- if not self .reshape :
85- y = self .diag * x .ravel ()
86- else :
87- x = x .reshape (self .dims )
88- y = self .diag * x
86+ x = x .reshape (self .dims )
87+ y = self .diag * x
8988 return y .ravel ()
9089
9190 def _rmatvec (self , x ):
@@ -95,11 +94,8 @@ def _rmatvec(self, x):
9594 diagadj = self .diag .conj ()
9695 else :
9796 diagadj = self .diag
98- if not self .reshape :
99- y = diagadj * x .ravel ()
100- else :
101- x = x .reshape (self .dims )
102- y = diagadj * x
97+ x = x .reshape (self .dims )
98+ y = diagadj * x
10399 return y .ravel ()
104100
105101 def matrix (self ):
0 commit comments