4 changes: 4 additions & 0 deletions Makefile
@@ -68,3 +68,7 @@ run_examples:
 # Run tutorials using mpi
 run_tutorials:
 	sh mpi_examples.sh tutorials $(NUM_PROCESSES)
+
+# Run tutorials using nccl
+run_tutorials_nccl:
+	sh mpi_examples.sh tutorials_nccl $(NUM_PROCESSES)
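Usage note: following the existing targets, the NCCL tutorials can be run with `make run_tutorials_nccl NUM_PROCESSES=<n>`, assuming (as the `run_tutorials` target suggests) that `mpi_examples.sh` takes the tutorials folder and the number of processes as its two arguments.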
28 changes: 16 additions & 12 deletions docs/source/gpu.rst
@@ -131,26 +131,30 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
    * - modules
      - NCCL supported
    * - :class:`pylops_mpi.DistributedArray`
-     - /
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIVStack`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIHStack`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIBlockDiag`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIGradient`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPIFirstDerivative`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPISecondDerivative`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.basicoperators.MPILaplacian`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.optimization.basic.cg`
-     - Ongoing
+     - ✅
    * - :class:`pylops_mpi.optimization.basic.cgls`
-     - Ongoing
+     - ✅
+   * - :class:`pylops_mpi.signalprocessing.Fredholm1`
+     - Planned ⏳
    * - ISTA Solver
-     - Planned
+     - Planned ⏳
    * - Complex Numeric Data Type for NCCL
-     - Planned
+     - Planned ⏳
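As a quick orientation for the table above, here is a minimal sketch of how an NCCL-backed array is created and used; `initialize_nccl_comm` and the `engine="cupy"` flag are assumptions drawn from the project's GPU documentation, not from this diff:

    import pylops_mpi
    from pylops_mpi.utils._nccl import initialize_nccl_comm  # assumed helper location

    nccl_comm = initialize_nccl_comm()  # one NCCL communicator per MPI rank

    # base_comm_nccl is the same constructor argument threaded through this diff
    x = pylops_mpi.DistributedArray(global_shape=128,
                                    base_comm_nccl=nccl_comm,
                                    engine="cupy")  # local arrays live on the GPU
    x[:] = 1.0
    print(x.norm())  # the reduction now runs over NCCL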
12 changes: 8 additions & 4 deletions pylops_mpi/DistributedArray.py
@@ -649,9 +649,9 @@ def dot(self, dist_array):
         self._check_mask(dist_array)
         ncp = get_module(self.engine)
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
-        y = DistributedArray.to_dist(x=dist_array.local_array) \
+        y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
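This fix matters for broadcast arrays: `dot` re-scatters them through `to_dist`, and previously the temporaries were built with `to_dist`'s default communicator, dropping both a custom `base_comm` and the NCCL communicator of the inputs. A hypothetical sketch, reusing `nccl_comm` from above:

    from pylops_mpi import DistributedArray, Partition

    b = DistributedArray(global_shape=64, partition=Partition.BROADCAST,
                         base_comm_nccl=nccl_comm, engine="cupy")
    b[:] = 2.0
    val = b.dot(b)  # the temporary SCATTER copies now inherit base_comm_nccl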
@@ -704,6 +704,7 @@ def zeros_like(self):
         """
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
@@ -724,7 +725,7 @@ def norm(self, ord: Optional[int] = None,
             Axis along which vector norm needs to be computed. Defaults to ``-1``
         """
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
         if axis == -1:
             # Flatten the local arrays and calculate norm
@@ -739,6 +740,7 @@ def conj(self):
         """
         conj = DistributedArray(global_shape=self.global_shape,
                                 base_comm=self.base_comm,
+                                base_comm_nccl=self.base_comm_nccl,
                                 partition=self.partition,
                                 axis=self.axis,
                                 local_shapes=self.local_shapes,
@@ -753,6 +755,7 @@ def copy(self):
         """
         arr = DistributedArray(global_shape=self.global_shape,
                                base_comm=self.base_comm,
+                               base_comm_nccl=self.base_comm_nccl,
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
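These three hunks apply one pattern: every method that derives a new `DistributedArray` from an existing one (`zeros_like`, `conj`, `copy`) now forwards `base_comm_nccl`, so the result stays on the same NCCL communicator instead of silently falling back to the MPI default. A one-line sanity check (sketch, continuing the example above):

    z = x.conj()
    assert z.base_comm_nccl is x.base_comm_nccl  # previously this was None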
@@ -908,7 +911,8 @@ def asarray(self):
             Global Array gathered at all ranks

         """
-        return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
+        ncp = get_module(self.distarrays[0].engine)
+        return ncp.hstack([distarr.asarray().ravel() for distarr in self.distarrays])

     def _check_stacked_size(self, stacked_array):
         """Check that arrays have consistent size
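`np.hstack` cannot concatenate CuPy arrays, which refuse implicit device-to-host conversion, so gathering a stacked array with a CuPy engine would fail; dispatching through `get_module` (the `pylops.utils.backend` helper already used in `dot` above) selects the matching array module:

    from pylops.utils.backend import get_module

    ncp = get_module("numpy")  # -> numpy
    ncp = get_module("cupy")   # -> cupy, so ncp.hstack keeps the gathered data on the GPU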
6 changes: 4 additions & 2 deletions pylops_mpi/optimization/cls_basic.py
@@ -98,7 +98,8 @@ def setup(

         if show and self.rank == 0:
             if isinstance(x, StackedDistributedArray):
-                self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
+                is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
+                self._print_setup(is_complex)
             else:
                 self._print_setup(np.iscomplexobj(x.local_array))
         return x
@@ -354,7 +355,8 @@ def setup(self,
         # print setup
         if show and self.rank == 0:
             if isinstance(x, StackedDistributedArray):
-                self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
+                is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
+                self._print_setup(is_complex)
             else:
                 self._print_setup(np.iscomplexobj(x.local_array))
         return x
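The same fix appears in both solvers' `setup` methods. `np.iscomplexobj` applied to a Python list first materializes it with `np.asarray`, which copies every local array, raises on ragged shapes in recent NumPy, and fails outright for CuPy arrays (again, no implicit host conversion). Testing each local array and combining with `any` sidesteps all of that:

    import numpy as np

    local_arrays = [np.ones(3), np.ones(5, dtype=np.complex128)]  # illustrative only
    is_complex = any(np.iscomplexobj(a) for a in local_arrays)    # True, with no copies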