Commit 19775a3

minor fixes based on recently merged PR
1 parent 30e5afd commit 19775a3

2 files changed: +7 -7 lines changed

pylops_mpi/DistributedArray.py (4 additions, 4 deletions)

@@ -380,7 +380,7 @@ def asarray(self, masked: bool = False):
             return self.local_array

         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_asarray(self.sub_comm if masked else self.base_comm,
+            return nccl_asarray(self.sub_comm if masked else self.base_comm_nccl,
                                 self.local_array, self.local_shapes, self.axis)
         else:
             # Gather all the local arrays and apply concatenation.
@@ -640,9 +640,9 @@ def dot(self, dist_array):
         self._check_mask(dist_array)
         ncp = get_module(self.engine)
         # Convert to Partition.SCATTER if Partition.BROADCAST
-        x = DistributedArray.to_dist(x=self.local_array, base_comm_nccl=self.base_comm_nccl) \
+        x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
-        y = DistributedArray.to_dist(x=dist_array.local_array, base_comm_nccl=self.base_comm_nccl) \
+        y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
         return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
@@ -716,7 +716,7 @@ def norm(self, ord: Optional[int] = None,
            Axis along which vector norm needs to be computed. Defaults to ``-1``
        """
        # Convert to Partition.SCATTER if Partition.BROADCAST
-       x = DistributedArray.to_dist(x=self.local_array, base_comm_nccl=self.base_comm_nccl) \
+       x = DistributedArray.to_dist(x=self.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
            if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else self
        if axis == -1:
            # Flatten the local arrays and calculate norm
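All three hunks fix the same seam between the MPI and NCCL communicators: asarray must hand nccl_asarray the NCCL communicator rather than the MPI one, and the broadcast-to-scatter conversions inside dot and norm must now forward base_comm explicitly instead of dropping it. A minimal sketch that drives these code paths on the plain-MPI side (the array size, fill values, and launch command are illustrative assumptions, not taken from the commit):

# Hypothetical usage sketch, not part of the commit.
# Run with e.g.: mpiexec -n 2 python sketch.py
import numpy as np
import pylops_mpi
from pylops_mpi import DistributedArray, Partition

# Under BROADCAST every rank holds the full vector, so dot() and norm()
# first re-scatter it through DistributedArray.to_dist -- the code path
# that now receives base_comm alongside base_comm_nccl.
x = DistributedArray(global_shape=8, partition=Partition.BROADCAST)
x[:] = np.arange(8.0)

d = x.dot(x)        # converts to SCATTER internally, then allreduces partial dots
n = x.norm()        # same broadcast-to-scatter conversion on the norm path
full = x.asarray()  # with NCCL enabled this is the nccl_asarray branch

if x.rank == 0:     # rank attribute assumed from the DistributedArray API
    print(d, n, full.shape)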

tutorials/poststack_nccl.py (3 additions, 3 deletions)

@@ -80,7 +80,7 @@
 BDiag = pylops_mpi.basicoperators.MPIBlockDiag(ops=[Top.H @ PPop @ Top, ])

 # This computation will be done in GPU. The call asarray() trigger the NCCL communication (gather result from each GPU).
-# But array `d` and `d_0` still lives in GPU memory
+# But array `d` and `d_0` still live in GPU memory
 d_dist = BDiag @ m3d_dist
 d_local = d_dist.local_array.reshape((ny_i, nx, nz))
 d = d_dist.asarray().reshape((ny, nx, nz))
@@ -89,8 +89,8 @@

 # ###############################################################################

-# Inversion using CGLS solver - There is no code change to have run on NCCL
-# In this particular case, the local computation will be done in GPU. And the collective communication calls
+# Inversion using CGLS solver - There is no code change to have run on NCCL (it handles though MPI operator and DistributedArray)
+# In this particular case, the local computation will be done in GPU. Collective communication calls
 # will be carried through NCCL GPU-to-GPU.
 minv3d_iter_dist = pylops_mpi.optimization.basic.cgls(BDiag, d_dist, x0=mback3d_dist, niter=1, show=True)[0]
 minv3d_iter = minv3d_iter_dist.asarray().reshape((ny, nx, nz))
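The point of the edited comment is that the solver call is backend-agnostic: cgls only sees the MPI operator and the DistributedArray, so the NCCL-versus-MPI choice is fixed when the array and communicator are constructed, not at solve time. A self-contained toy sketch of the same call on the plain-MPI side (the Identity blocks, sizes, and iteration count are placeholders, not from the tutorial; an even split of the global array across ranks is assumed):

# Hypothetical toy sketch, not part of the tutorial.
# Run with e.g.: mpiexec -n 2 python sketch.py
import numpy as np
from mpi4py import MPI
import pylops
import pylops_mpi

n = 10
size = MPI.COMM_WORLD.Get_size()

# One local Identity block per rank; the cgls call below is written exactly
# as in the tutorial, and would be unchanged with a NCCL-backed array.
Bop = pylops_mpi.basicoperators.MPIBlockDiag(ops=[pylops.Identity(n)])

y = pylops_mpi.DistributedArray(global_shape=n * size)   # default SCATTER partition
y[:] = np.ones(n)
x0 = pylops_mpi.DistributedArray(global_shape=n * size)
x0[:] = np.zeros(n)

xinv = pylops_mpi.optimization.basic.cgls(Bop, y, x0=x0, niter=10, show=False)[0]
xfull = xinv.asarray()  # gather; a NCCL collective when NCCL is enabled

With a NCCL-backed DistributedArray the identical call would route its collectives GPU-to-GPU, which is exactly what the tutorial's comment asserts.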
