Skip to content

Commit 83c336a

Browse files
authored
Merge pull request #141 from tharittk/stacked_op_array_nccl_test
NCCL support for Stacked op/array & Solver with doc and tutorial update
2 parents a700c63 + 17d9dbe commit 83c336a

File tree

10 files changed

+1463
-18
lines changed

10 files changed

+1463
-18
lines changed

Makefile

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ doc:
5555
rm -rf source/tutorials && rm -rf build &&\
5656
cd .. && sphinx-build -b html docs/source docs/build
5757

58+
doc_nccl:
59+
cp tutorials_nccl/* tutorials/
60+
cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
61+
rm -rf source/tutorials && rm -rf source/tutorials && rm -rf build &&\
62+
cd .. && sphinx-build -b html docs/source docs/build
63+
rm tutorials/*_nccl.py
64+
5865
docupdate:
5966
cd docs && make html && cd ..
6067

@@ -68,3 +75,7 @@ run_examples:
6875
# Run tutorials using mpi
6976
run_tutorials:
7077
sh mpi_examples.sh tutorials $(NUM_PROCESSES)
78+
79+
# Run tutorials using nccl
80+
run_tutorials_nccl:
81+
sh mpi_examples.sh tutorials_nccl $(NUM_PROCESSES)

docs/source/gpu.rst

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,26 +131,30 @@ In the following, we provide a list of modules (i.e., operators and solvers) whe
131131
* - modules
132132
- NCCL supported
133133
* - :class:`pylops_mpi.DistributedArray`
134-
- /
134+
-
135135
* - :class:`pylops_mpi.basicoperators.MPIVStack`
136-
- Ongoing
136+
- ✅
137+
* - :class:`pylops_mpi.basicoperators.MPIVStack`
138+
- ✅
137139
* - :class:`pylops_mpi.basicoperators.MPIHStack`
138-
- Ongoing
140+
-
139141
* - :class:`pylops_mpi.basicoperators.MPIBlockDiag`
140-
- Ongoing
142+
-
141143
* - :class:`pylops_mpi.basicoperators.MPIGradient`
142-
- Ongoing
144+
-
143145
* - :class:`pylops_mpi.basicoperators.MPIFirstDerivative`
144-
- Ongoing
146+
-
145147
* - :class:`pylops_mpi.basicoperators.MPISecondDerivative`
146-
- Ongoing
148+
-
147149
* - :class:`pylops_mpi.basicoperators.MPILaplacian`
148-
- Ongoing
150+
-
149151
* - :class:`pylops_mpi.optimization.basic.cg`
150-
- Ongoing
152+
-
151153
* - :class:`pylops_mpi.optimization.basic.cgls`
152-
- Ongoing
154+
- ✅
155+
* - :class:`pylops_mpi.signalprocessing.Fredhoml1`
156+
- Planned ⏳
153157
* - ISTA Solver
154-
- Planned
158+
- Planned
155159
* - Complex Numeric Data Type for NCCL
156-
- Planned
160+
- Planned

pylops_mpi/DistributedArray.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -646,9 +646,9 @@ def dot(self, dist_array):
646646
self._check_mask(dist_array)
647647
ncp = get_module(self.engine)
648648
# Convert to Partition.SCATTER if Partition.BROADCAST
649-
x = DistributedArray.to_dist(x=self.local_array) \
649+
x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
650650
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
651-
y = DistributedArray.to_dist(x=dist_array.local_array) \
651+
y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
652652
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
653653
# Flatten the local arrays and calculate dot product
654654
return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
@@ -701,6 +701,7 @@ def zeros_like(self):
701701
"""
702702
arr = DistributedArray(global_shape=self.global_shape,
703703
base_comm=self.base_comm,
704+
base_comm_nccl=self.base_comm_nccl,
704705
partition=self.partition,
705706
axis=self.axis,
706707
local_shapes=self.local_shapes,
@@ -721,7 +722,7 @@ def norm(self, ord: Optional[int] = None,
721722
Axis along which vector norm needs to be computed. Defaults to ``-1``
722723
"""
723724
# Convert to Partition.SCATTER if Partition.BROADCAST
724-
x = DistributedArray.to_dist(x=self.local_array) \
725+
x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
725726
if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
726727
if axis == -1:
727728
# Flatten the local arrays and calculate norm
@@ -736,6 +737,7 @@ def conj(self):
736737
"""
737738
conj = DistributedArray(global_shape=self.global_shape,
738739
base_comm=self.base_comm,
740+
base_comm_nccl=self.base_comm_nccl,
739741
partition=self.partition,
740742
axis=self.axis,
741743
local_shapes=self.local_shapes,
@@ -750,6 +752,7 @@ def copy(self):
750752
"""
751753
arr = DistributedArray(global_shape=self.global_shape,
752754
base_comm=self.base_comm,
755+
base_comm_nccl=self.base_comm_nccl,
753756
partition=self.partition,
754757
axis=self.axis,
755758
local_shapes=self.local_shapes,
@@ -905,7 +908,8 @@ def asarray(self):
905908
Global Array gathered at all ranks
906909
907910
"""
908-
return np.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
911+
ncp = get_module(self.distarrays[0].engine)
912+
return ncp.hstack([distarr.asarray().ravel() for distarr in self.distarrays])
909913

910914
def _check_stacked_size(self, stacked_array):
911915
"""Check that arrays have consistent size

pylops_mpi/optimization/cls_basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def setup(
9898

9999
if show and self.rank == 0:
100100
if isinstance(x, StackedDistributedArray):
101-
self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
101+
is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
102+
self._print_setup(is_complex)
102103
else:
103104
self._print_setup(np.iscomplexobj(x.local_array))
104105
return x
@@ -354,7 +355,8 @@ def setup(self,
354355
# print setup
355356
if show and self.rank == 0:
356357
if isinstance(x, StackedDistributedArray):
357-
self._print_setup(np.iscomplexobj([x1.local_array for x1 in x.distarrays]))
358+
is_complex = any(np.iscomplexobj(x1.local_array) for x1 in x.distarrays)
359+
self._print_setup(is_complex)
358360
else:
359361
self._print_setup(np.iscomplexobj(x.local_array))
360362
return x

0 commit comments

Comments
 (0)