|
4 | 4 | import numpy as np |
5 | 5 |
|
6 | 6 | from pylops.optimization.basesolver import Solver |
7 | | -from pylops.utils import NDArray |
| 7 | +from pylops.utils import NDArray, get_module |
8 | 8 |
|
9 | 9 | from pylops_mpi import DistributedArray, StackedDistributedArray |
10 | 10 |
|
@@ -98,7 +98,10 @@ def setup( |
98 | 98 |
|
99 | 99 | if show and self.rank == 0: |
100 | 100 | if isinstance(x, StackedDistributedArray): |
101 | | - self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays])) |
| 101 | + # cupy iscomplexobj fallback to numpy iscomplexobject if passing the list |
| 102 | + # so it has to be made asarray first |
| 103 | + ncp = get_module(x.distarrays[0].engine) |
| 104 | + self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays]))) |
102 | 105 | else: |
103 | 106 | self._print_setup(np.iscomplexobj(x.local_array)) |
104 | 107 | return x |
@@ -354,7 +357,10 @@ def setup(self, |
354 | 357 | # print setup |
355 | 358 | if show and self.rank == 0: |
356 | 359 | if isinstance(x, StackedDistributedArray): |
357 | | - self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays])) |
| 360 | + # cupy iscomplexobj fallback to numpy iscomplexobject if passing the list |
| 361 | + # so it has to be made asarray first |
| 362 | + ncp = get_module(x.distarrays[0].engine) |
| 363 | + self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays]))) |
358 | 364 | else: |
359 | 365 | self._print_setup(np.iscomplexobj(x.local_array)) |
360 | 366 | return x |
|
0 commit comments