Skip to content

Commit 3f9b941

Browse files
committed
bug: fix linearoperator for numpy2
1 parent 9da0fd4 commit 3f9b941

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

pylops/linearoperator.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,23 +1242,14 @@ def _get_dtype(
12421242
) -> DTypeLike:
12431243
if dtypes is None:
12441244
dtypes = []
1245-
opdtypes = []
12461245
for obj in operators:
12471246
if obj is not None and hasattr(obj, "dtype"):
1248-
opdtypes.append(obj.dtype)
1249-
return np.find_common_type(opdtypes, dtypes)
1247+
dtypes.append(obj.dtype)
1248+
return np.result_type(*dtypes)
12501249

12511250

12521251
class _ScaledLinearOperator(LinearOperator):
1253-
"""
1254-
Sum Linear Operator
1255-
1256-
Modified version of scipy _ScaledLinearOperator which uses a modified
1257-
_get_dtype where the scalar and operator types are passed separately to
1258-
np.find_common_type. Passing them together does lead to problems when using
1259-
np.float32 operators which are cast to np.float64
1260-
1261-
"""
1252+
"""Scaled Linear Operator"""
12621253

12631254
def __init__(
12641255
self,
@@ -1269,7 +1260,15 @@ def __init__(
12691260
raise ValueError("LinearOperator expected as A")
12701261
if not np.isscalar(alpha):
12711262
raise ValueError("scalar expected as alpha")
1272-
dtype = _get_dtype([A], [type(alpha)])
1263+
if isinstance(alpha, complex) and not np.iscomplexobj(
1264+
np.ones(1, dtype=A.dtype)
1265+
):
1266+
# if the scalar is of complex type but not the operator, find out type
1267+
dtype = _get_dtype([A], [type(alpha)])
1268+
else:
1269+
# if both the scalar and operator are of real or complex type, use type
1270+
# of the operator
1271+
dtype = A.dtype
12731272
super(_ScaledLinearOperator, self).__init__(dtype=dtype, shape=A.shape)
12741273
self.args = (A, alpha)
12751274

@@ -1465,7 +1464,7 @@ def __init__(self, A: LinearOperator, p: int) -> None:
14651464
if not isintlike(p) or p < 0:
14661465
raise ValueError("non-negative integer expected as p")
14671466

1468-
super(_PowerLinearOperator, self).__init__(dtype=_get_dtype([A]), shape=A.shape)
1467+
super(_PowerLinearOperator, self).__init__(dtype=A.dtype, shape=A.shape)
14691468
self.args = (A, p)
14701469

14711470
def _power(self, fun: Callable, x: NDArray) -> NDArray:

0 commit comments

Comments
 (0)