Skip to content

Commit 5faa1b1

Browse files
committed
nccl solver and poststack inv tutorial draft
1 parent 6d6f9dc commit 5faa1b1

File tree

3 files changed

+617
-3
lines changed

3 files changed

+617
-3
lines changed

pylops_mpi/optimization/cls_basic.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from pylops.optimization.basesolver import Solver
7-
from pylops.utils import NDArray
7+
from pylops.utils import NDArray, get_module
88

99
from pylops_mpi import DistributedArray, StackedDistributedArray
1010

@@ -98,7 +98,10 @@ def setup(
9898

9999
if show and self.rank == 0:
100100
if isinstance(x, StackedDistributedArray):
101-
self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
101+
# cupy iscomplexobj fallback to numpy iscomplexobject if passing the list
102+
# so it has to be made asarray first
103+
ncp = get_module(x.distarrays[0].engine)
104+
self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays])))
102105
else:
103106
self._print_setup(np.iscomplexobj(x.local_array))
104107
return x
@@ -354,7 +357,10 @@ def setup(self,
354357
# print setup
355358
if show and self.rank == 0:
356359
if isinstance(x, StackedDistributedArray):
357-
self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
360+
# cupy iscomplexobj fallback to numpy iscomplexobject if passing the list
361+
# so it has to be made asarray first
362+
ncp = get_module(x.distarrays[0].engine)
363+
self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays])))
358364
else:
359365
self._print_setup(np.iscomplexobj(x.local_array))
360366
return x

0 commit comments

Comments
 (0)