Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ doc:
rm -rf source/tutorials && rm -rf build &&\
cd .. && sphinx-build -b html docs/source docs/build

# Build the HTML docs with the NCCL tutorial variants swapped in:
# copy tutorials_nccl/* over tutorials/, rebuild from a clean state,
# then remove the copied *_nccl.py files so the tree is left unchanged.
# NOTE(review): the duplicated `rm -rf source/tutorials` was removed —
# the same directory was deleted twice in a row.  Make recipes must be
# tab-indented; tabs (lost in the page copy) are restored here.
doc_nccl:
	cp tutorials_nccl/* tutorials/
	cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
	rm -rf source/tutorials && rm -rf build &&\
	cd .. && sphinx-build -b html docs/source docs/build
	rm tutorials/*_nccl.py

docupdate:
cd docs && make html && cd ..

Expand All @@ -68,3 +75,7 @@ run_examples:
# Run tutorials using mpi
run_tutorials:
sh mpi_examples.sh tutorials $(NUM_PROCESSES)

# Run tutorials using nccl
# Dispatches through the same mpi_examples.sh launcher as run_tutorials,
# but pointed at the tutorials_nccl directory; NUM_PROCESSES sets the
# number of MPI processes passed to the script.
run_tutorials_nccl:
sh mpi_examples.sh tutorials_nccl $(NUM_PROCESSES)
28 changes: 16 additions & 12 deletions docs/source/gpu.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,26 +131,30 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
* - modules
- NCCL supported
* - :class:`pylops_mpi.DistributedArray`
- /
-
* - :class:`pylops_mpi.basicoperators.MPIVStack`
- ✅
* - :class:`pylops_mpi.basicoperators.MPIHStack`
- Ongoing
-
* - :class:`pylops_mpi.basicoperators.MPIBlockDiag`
- Ongoing
-
* - :class:`pylops_mpi.basicoperators.MPIGradient`
- Ongoing
-
* - :class:`pylops_mpi.basicoperators.MPIFirstDerivative`
- Ongoing
-
* - :class:`pylops_mpi.basicoperators.MPISecondDerivative`
- Ongoing
-
* - :class:`pylops_mpi.basicoperators.MPILaplacian`
- Ongoing
-
* - :class:`pylops_mpi.optimization.basic.cg`
- Ongoing
-
* - :class:`pylops_mpi.optimization.basic.cgls`
- Ongoing
- ✅
* - :class:`pylops_mpi.signalprocessing.Fredholm1`
- Planned ⏳
* - ISTA Solver
- Planned
- Planned
* - Complex Numeric Data Type for NCCL
- Planned
- Planned
12 changes: 8 additions & 4 deletions pylops_mpi/DistributedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,9 @@ def dot(self, dist_array):
self._check_mask(dist_array)
ncp = get_module(self.engine)
# Convert to Partition.SCATTER if Partition.BROADCAST
x = DistributedArray.to_dist(x=self.local_array) \
x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
y = DistributedArray.to_dist(x=dist_array.local_array) \
y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
# Flatten the local arrays and calculate dot product
return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
Expand Down Expand Up @@ -701,6 +701,7 @@ def zeros_like(self):
"""
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
Expand All @@ -721,7 +722,7 @@ def norm(self, ord: Optional[int] = None,
Axis along which vector norm needs to be computed. Defaults to ``-1``
"""
# Convert to Partition.SCATTER if Partition.BROADCAST
x = DistributedArray.to_dist(x=self.local_array) \
x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
if axis == -1:
# Flatten the local arrays and calculate norm
Expand All @@ -736,6 +737,7 @@ def conj(self):
"""
conj = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
Expand All @@ -750,6 +752,7 @@ def copy(self):
"""
arr = DistributedArray(global_shape=self.global_shape,
base_comm=self.base_comm,
base_comm_nccl=self.base_comm_nccl,
partition=self.partition,
axis=self.axis,
local_shapes=self.local_shapes,
Expand Down Expand Up @@ -905,7 +908,8 @@ def asarray(self):
Global Array gathered at all ranks

"""
return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
ncp = get_module(self.distarrays[0].engine)
return ncp.hstack([distarr.asarray().ravel() for distarr in self.distarrays])

def _check_stacked_size(self, stacked_array):
"""Check that arrays have consistent size
Expand Down
6 changes: 4 additions & 2 deletions pylops_mpi/optimization/cls_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def setup(

if show and self.rank == 0:
if isinstance(x, StackedDistributedArray):
self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
self._print_setup(is_complex)
else:
self._print_setup(np.iscomplexobj(x.local_array))
return x
Expand Down Expand Up @@ -354,7 +355,8 @@ def setup(self,
# print setup
if show and self.rank == 0:
if isinstance(x, StackedDistributedArray):
self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
self._print_setup(is_complex)
else:
self._print_setup(np.iscomplexobj(x.local_array))
return x
Expand Down
Loading