docs/source/gpu.rst (14 additions, 12 deletions)

@@ -131,26 +131,28 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
    * - modules
      - NCCL supported
    * - :class:`pylops_mpi.DistributedArray`
-     - /
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIVStack`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIHStack`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIBlockDiag`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIGradient`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIFirstDerivative`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPISecondDerivative`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPILaplacian`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.optimization.basic.cg`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.optimization.basic.cgls`
-     - Ongoing
+     - ✅
    * - ISTA Solver
      - Planned
    * - Complex Numeric Data Type for NCCL
      - Planned
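Note: with the operators and the cg/cgls solvers above now NCCL-enabled, a typical GPU workflow looks like the sketch below. This is a hypothetical setup, not code from the PR: initialize_nccl_comm is assumed to be the library's NCCL bootstrap helper, and the script is assumed to run with one GPU per MPI rank.

    # Sketch: a NCCL-backed DistributedArray (assumed helper and setup, see note)
    import cupy as cp
    import pylops_mpi
    from pylops_mpi.utils._nccl import initialize_nccl_comm

    nccl_comm = initialize_nccl_comm()              # one NCCL communicator per rank
    x = pylops_mpi.DistributedArray(global_shape=128,
                                    base_comm_nccl=nccl_comm,
                                    engine="cupy")  # local chunks live on the GPU
    x[:] = cp.ones(x.local_shape)
    print(x.norm())                                 # reduction runs over NCCL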
pylops_mpi/DistributedArray.py (9 additions, 5 deletions)

@@ -380,7 +380,7 @@ def asarray(self, masked: bool = False):
             return self.local_array

         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_asarray(self.sub_comm if masked else self.base_comm,
+            return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
                                 self.local_array, self.local_shapes, self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
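Note: this is the key fix in the hunk above. base_comm holds the mpi4py communicator while base_comm_nccl holds the NCCL one, and nccl_asarray can only drive NCCL collectives, so passing base_comm broke asarray() on the GPU path. A usage sketch, reusing the hypothetical nccl_comm setup from the first example:

    # Sketch: gather the global array from a NCCL-backed DistributedArray;
    # asarray() now hands nccl_asarray the NCCL communicator (base_comm_nccl)
    arr = pylops_mpi.DistributedArray(global_shape=32,
                                      base_comm_nccl=nccl_comm,
                                      engine="cupy")
    arr[:] = cp.ones(arr.local_shape)
    full = arr.asarray()   # NCCL allgather, then on-device concatenation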
@@ -640,9 +640,9 @@ def dot(self, dist_array):
         self._check_mask(dist_array)
         ncp = get_module(self.engine)
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
-        y = DistributedArray.to_dist(x=dist_array.local_array) \
+        y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
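Note: to_dist builds a fresh scattered copy of a broadcast array, and it was previously called without communicators, so dot() on BROADCAST-partitioned NCCL arrays silently fell back to the default MPI communicator; the same forwarding fix is applied in norm() below. A sketch under the same hypothetical setup:

    # Sketch: dot on broadcast-partitioned arrays re-scatters via to_dist,
    # which must inherit base_comm_nccl to stay on the NCCL path
    from pylops_mpi import Partition
    xb = pylops_mpi.DistributedArray(global_shape=8,
                                     base_comm_nccl=nccl_comm,
                                     partition=Partition.BROADCAST,
                                     engine="cupy")
    xb[:] = cp.ones(8)
    val = xb.dot(xb)   # allreduce now runs over NCCL, not plain MPI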
@@ -695,6 +695,7 @@ def zeros_like(self):
"""
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
@@ -715,7 +716,7 @@ def norm(self, ord: Optional[int] = None,
             Axis along which vector norm needs to be computed. Defaults to ``-1``
         """
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         if axis == -1:
             # Flatten the local arrays and calculate norm
@@ -730,6 +731,7 @@ def conj(self):
"""
conj = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
@@ -744,6 +746,7 @@ def copy(self):
"""
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
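Note: zeros_like, conj, and copy all construct a new DistributedArray from self, and each of the three hunks above forwards base_comm_nccl so that derived arrays keep the NCCL communicator. A sketch of the invariant this establishes, under the same hypothetical setup:

    # Sketch: without the forwarded argument, a chain like x.copy().norm()
    # would silently drop back to MPI collectives
    x = pylops_mpi.DistributedArray(global_shape=64,
                                    base_comm_nccl=nccl_comm,
                                    engine="cupy")
    x[:] = cp.ones(x.local_shape)
    for derived in (x.zeros_like(), x.conj(), x.copy()):
        assert derived.base_comm_nccl is x.base_comm_nccl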
@@ -899,7 +902,8 @@ def asarray(self):
         Global Array gathered at all ranks

         """
-        return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
+        ncp = get_module(self.distarrays[0].engine)
+        return ncp.hstack([distarr.asarray().ravel() for distarr in self.distarrays])

     def _check_stacked_size(self, stacked_array):
         """Check that arrays have consistent size
pylops_mpi/optimization/cls_basic.py (9 additions, 3 deletions)

@@ -4,7 +4,7 @@
 import numpy as np

 from pylops.optimization.basesolver import Solver
-from pylops.utils import NDArray
+from pylops.utils import NDArray, get_module

 from pylops_mpi import DistributedArray, StackedDistributedArray

@@ -98,7 +98,10 @@ def setup(

         if show and self.rank == 0:
             if isinstance(x, StackedDistributedArray):
-                self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
+                # cupy's iscomplexobj falls back to numpy's when handed a plain
+                # list, so convert to an array on the right backend first
+                ncp = get_module(x.distarrays[0].engine)
+                self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays])))
             else:
                 self._print_setup(np.iscomplexobj(x.local_array))
         return x
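Note: the new comment refers to the dtype probe inside iscomplexobj: handed a plain Python list, it falls back to building an array with numpy's asarray, which cannot ingest cupy arrays; converting with the backend's own asarray first avoids that. A sketch of the distinction (GPU required; behavior as described by the PR comment rather than verified here):

    # Sketch: checking the complex dtype of device-resident local arrays
    import cupy as cp
    locals_ = [cp.ones(4, dtype=cp.complex64) for _ in range(2)]
    stacked = cp.asarray(locals_)     # a proper (2, 4) cupy array on device
    print(cp.iscomplexobj(stacked))   # True, read directly off stacked.dtype
    # cp.iscomplexobj(locals_) would push the raw list through numpy's
    # asarray, which rejects implicit cupy -> numpy conversion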
@@ -354,7 +357,10 @@ def setup(self,
         # print setup
         if show and self.rank == 0:
             if isinstance(x, StackedDistributedArray):
-                self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
+                # cupy's iscomplexobj falls back to numpy's when handed a plain
+                # list, so convert to an array on the right backend first
+                ncp = get_module(x.distarrays[0].engine)
+                self._print_setup(ncp.iscomplexobj(ncp.asarray([x1.local_array for x1 in x.distarrays])))
             else:
                 self._print_setup(np.iscomplexobj(x.local_array))
         return x