@@ -1242,23 +1242,14 @@ def _get_dtype(
12421242) -> DTypeLike :
12431243 if dtypes is None :
12441244 dtypes = []
1245- opdtypes = []
12461245 for obj in operators :
12471246 if obj is not None and hasattr (obj , "dtype" ):
1248- opdtypes .append (obj .dtype )
1249- return np .find_common_type ( opdtypes , dtypes )
1247+ dtypes .append (obj .dtype )
1248+ return np .result_type ( * dtypes )
12501249
12511250
12521251class _ScaledLinearOperator (LinearOperator ):
1253- """
1254- Sum Linear Operator
1255-
1256- Modified version of scipy _ScaledLinearOperator which uses a modified
1257- _get_dtype where the scalar and operator types are passed separately to
1258- np.find_common_type. Passing them together does lead to problems when using
1259- np.float32 operators which are cast to np.float64
1260-
1261- """
1252+ """Scaled Linear Operator"""
12621253
12631254 def __init__ (
12641255 self ,
@@ -1269,7 +1260,15 @@ def __init__(
12691260 raise ValueError ("LinearOperator expected as A" )
12701261 if not np .isscalar (alpha ):
12711262 raise ValueError ("scalar expected as alpha" )
1272- dtype = _get_dtype ([A ], [type (alpha )])
1263+ if isinstance (alpha , complex ) and not np .iscomplexobj (
1264+ np .ones (1 , dtype = A .dtype )
1265+ ):
1266+ # if the scalar is of complex type but not the operator, find out type
1267+ dtype = _get_dtype ([A ], [type (alpha )])
1268+ else :
1269+ # if both the scalar and operator are of real or complex type, use type
1270+ # of the operator
1271+ dtype = A .dtype
12731272 super (_ScaledLinearOperator , self ).__init__ (dtype = dtype , shape = A .shape )
12741273 self .args = (A , alpha )
12751274
@@ -1465,7 +1464,7 @@ def __init__(self, A: LinearOperator, p: int) -> None:
14651464 if not isintlike (p ) or p < 0 :
14661465 raise ValueError ("non-negative integer expected as p" )
14671466
1468- super (_PowerLinearOperator , self ).__init__ (dtype = _get_dtype ([ A ]) , shape = A .shape )
1467+ super (_PowerLinearOperator , self ).__init__ (dtype = A . dtype , shape = A .shape )
14691468 self .args = (A , p )
14701469
14711470 def _power (self , fun : Callable , x : NDArray ) -> NDArray :
0 commit comments