@@ -15,7 +15,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")
 
 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_asarray, nccl_split
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -204,10 +204,7 @@ def __setitem__(self, index, value):
         the specified index positions.
         """
         if self.partition is Partition.BROADCAST:
-            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-                nccl_bcast(self.base_comm_nccl, self.local_array, index, value)
-            else:
-                self.local_array[index] = self.base_comm.bcast(value)
+            self._bcast(self.local_array, index, value)
         else:
             self.local_array[index] = value
 
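The removed branch above shows the MPI/NCCL dispatch that the new `self._bcast(...)` call presumably encapsulates. Below is a hedged reconstruction of such a helper, built only from the deleted lines; the actual pylops-mpi `_bcast` signature may differ.

```python
# Sketch reconstructed from the removed branch; not necessarily the real pylops_mpi _bcast.
# `deps`, `nccl_bcast`, `base_comm` and `base_comm_nccl` are the names already used in this file.
def _bcast(self, local_array, index, value):
    """Broadcast `value` into `local_array[index]` on every rank."""
    if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
        # GPU-side broadcast through the NCCL communicator (as in the removed branch)
        nccl_bcast(self.base_comm_nccl, local_array, index, value)
    else:
        # host-side broadcast through mpi4py
        local_array[index] = self.base_comm.bcast(value)
```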
@@ -343,7 +340,9 @@ def local_shapes(self):
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return self._nccl_local_shapes(False)
         else:
-            return self._allgather(self.local_shape)
+            return self._allgather(self.base_comm,
+                                   self.base_comm_nccl,
+                                   self.local_shape)
 
     @property
     def sub_comm(self):
@@ -383,7 +382,10 @@ def asarray(self, masked: bool = False):
         if masked:
             final_array = self._allgather_subcomm(self.local_array)
         else:
-            final_array = self._allgather(self.local_array)
+            final_array = self._allgather(self.base_comm,
+                                          self.base_comm_nccl,
+                                          self.local_array,
+                                          engine=self.engine)
         return np.concatenate(final_array, axis=self.axis)
 
     @classmethod
@@ -433,6 +435,7 @@ def to_dist(cls, x: NDArray,
         else:
             slices = [slice(None)] * x.ndim
             local_shapes = np.append([0], dist_array._allgather(
+                base_comm, base_comm_nccl,
                 dist_array.local_shape[axis]))
             sum_shapes = np.cumsum(local_shapes)
             slices[axis] = slice(sum_shapes[dist_array.rank],
@@ -480,7 +483,9 @@ def _nccl_local_shapes(self, masked: bool):
         if masked:
             all_tuples = self._allgather_subcomm(self.local_shape).get()
         else:
-            all_tuples = self._allgather(self.local_shape).get()
+            all_tuples = self._allgather(self.base_comm,
+                                         self.base_comm_nccl,
+                                         self.local_shape).get()
         # NCCL returns the flat array that packs every tuple as 1-dimensional array
         # unpack each tuple from each rank
         tuple_len = len(self.local_shape)
@@ -578,7 +583,9 @@ def dot(self, dist_array):
         y = DistributedArray.to_dist(x=dist_array.local_array, base_comm=self.base_comm, base_comm_nccl=self.base_comm_nccl) \
             if self.partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST] else dist_array
         # Flatten the local arrays and calculate dot product
-        return self._allreduce_subcomm(ncp.dot(x.local_array.flatten(), y.local_array.flatten()))
+        return self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                       ncp.dot(x.local_array.flatten(), y.local_array.flatten()),
+                                       engine=self.engine)
 
     def _compute_vector_norm(self, local_array: NDArray,
                              axis: int, ord: Optional[int] = None):
@@ -606,7 +613,9 @@ def _compute_vector_norm(self, local_array: NDArray,
             raise ValueError(f"norm-{ord} not possible for vectors")
         elif ord == 0:
             # Count non-zero then sum reduction
-            recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
+            recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                               ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64),
+                                               engine=self.engine)
         elif ord == ncp.inf:
             # Calculate max followed by max reduction
             # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
@@ -615,10 +624,14 @@ def _compute_vector_norm(self, local_array: NDArray,
             if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
                 # CuPy + non-CUDA-aware MPI: This will call non-buffered communication
                 # which return a list of object - must be copied back to a GPU memory.
-                recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
+                recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                                   send_buf.get(), recv_buf.get(),
+                                                   op=MPI.MAX, engine=self.engine)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
-                recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
+                recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                                   send_buf, recv_buf, op=MPI.MAX,
+                                                   engine=self.engine)
             # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL
             # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it.
             # There may be a way to unify it - may be something to do with how we allocate the recv_buf.
@@ -629,14 +642,20 @@ def _compute_vector_norm(self, local_array: NDArray,
             # See the comment above in +infinity norm
             send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
             if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
-                recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
+                recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                                   send_buf.get(), recv_buf.get(),
+                                                   op=MPI.MIN, engine=self.engine)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
-                recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
+                recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                                   send_buf, recv_buf,
+                                                   op=MPI.MIN, engine=self.engine)
             if self.base_comm_nccl:
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
         else:
-            recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))
+            recv_buf = self._allreduce_subcomm(self.sub_comm, self.base_comm_nccl,
+                                               ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis),
+                                               engine=self.engine)
             recv_buf = ncp.power(recv_buf, 1.0 / ord)
         return recv_buf
 
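Taken together, the hunks above apply one pattern: the collective helpers (`_allgather`, `_allreduce_subcomm`, `_bcast`) now receive the MPI communicator, the optional NCCL communicator, and the array engine explicitly instead of reading them from `self`. A minimal, self-contained sketch of what such an `_allgather` helper could look like is shown below; the name, signature, and NCCL fallback are illustrative assumptions, not the actual pylops-mpi implementation.

```python
# Illustrative sketch only: a standalone all-gather helper with the assumed
# signature _allgather(base_comm, base_comm_nccl, send_buf, engine=...).
from typing import Any, Optional

from mpi4py import MPI


def _allgather(base_comm: MPI.Comm,
               base_comm_nccl: Optional[Any],
               send_buf: Any,
               engine: str = "numpy") -> list:
    """Gather `send_buf` from every rank into a list with one entry per rank."""
    if base_comm_nccl is not None and engine == "cupy":
        # A device-side NCCL all-gather would go here (CuPy buffers); omitted in this sketch.
        raise NotImplementedError("NCCL path not sketched")
    # Object-based mpi4py all-gather works for tuples, scalars and small arrays.
    return base_comm.allgather(send_buf)


if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    local_shape = (comm.Get_rank() + 1, 3)  # each rank contributes its local shape
    print(_allgather(comm, None, local_shape))  # run with: mpirun -n 2 python sketch.py
```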