 from numbers import Integral
 from typing import List, Optional, Tuple, Union
 
-import cupy as cp
-import cupy.cuda.nccl as nccl
 import numpy as np
 from mpi4py import MPI
 from pylops.utils import DTypeLike, NDArray
 from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_array_module, get_module, get_module_name
-from pylops_mpi.utils.backend import cupy_to_nccl_dtype, mpi_op_to_nccl
+from pylops_mpi.utils.backend import nccl_split, nccl_allgather, nccl_allreduce, nccl_bcast, nccl_asarray
 
 
 class Partition(Enum):
@@ -62,7 +60,7 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
     return tuple(local_shape)
 
 
-def subcomm_split(mask, base_comm: MPI.Comm):
+def subcomm_split(mask, base_comm: MPI.Comm = MPI.COMM_WORLD):
     """Create new communicators based on mask
     This method creates new NCCL communicators based on ``mask``.
     Contrary to MPI, NCCL does not provide support for splitting of a communicator in multiple subcommunicators;
@@ -82,20 +80,12 @@ def subcomm_split(mask, base_comm: MPI.Comm):
     -------
     Union[mpi4py.MPI.Comm, cupy.cuda.nccl.NcclCommunicator]: a subcommunicator according to mask
     """
-    comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    sub_comm = comm.Split(color=mask[rank], key=rank)
-    if isinstance(base_comm, nccl.NcclCommunicator):
-        sub_rank = sub_comm.Get_rank()
-        sub_size = sub_comm.Get_size()
-
-        if sub_rank == 0:
-            nccl_id_bytes = nccl.get_unique_id()
-        else:
-            nccl_id_bytes = None
-        nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
-        sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)
-
+    if isinstance(base_comm, MPI.Comm):
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        sub_comm = comm.Split(color=mask[rank], key=rank)
+    else:
+        sub_comm = nccl_split(mask)
     return sub_comm
 
 
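The removed inline logic moves into `nccl_split` in `pylops_mpi/utils/backend.py`. A minimal sketch of what that helper plausibly does, reconstructed from the deleted lines above (the shipped implementation may differ in details):

```python
# Sketch only: nccl_split reconstructed from the inline code removed above.
from mpi4py import MPI
import cupy.cuda.nccl as nccl


def nccl_split(mask):
    # Split the MPI world communicator first; the MPI subcommunicator is used
    # only to exchange the NCCL unique id among ranks sharing the same color.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    sub_comm = comm.Split(color=mask[rank], key=rank)

    sub_rank = sub_comm.Get_rank()
    sub_size = sub_comm.Get_size()

    # Rank 0 of each subgroup creates the NCCL unique id and broadcasts it
    nccl_id_bytes = nccl.get_unique_id() if sub_rank == 0 else None
    nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)

    # Build the NCCL communicator spanning only the ranks of this subgroup
    return nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)
```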
@@ -138,8 +128,9 @@ class DistributedArray:
         Type of elements in input array. Defaults to ``numpy.float64``.
     """
 
+    # TODO: Type Annotation for base_comm without NCCL import
     def __init__(self, global_shape: Union[Tuple, Integral],
-                 base_comm: Optional[Union[MPI.Comm, nccl.NcclCommunicator]] = MPI.COMM_WORLD,
+                 base_comm=MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
                  local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  mask: Optional[List[Integral]] = None,
@@ -153,7 +144,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         if partition not in Partition:
             raise ValueError(f"Should be either {Partition.BROADCAST}, "
                              f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}")
-        if isinstance(base_comm, nccl.NcclCommunicator) and engine != "cupy":
+        if not isinstance(base_comm, MPI.Comm) and engine != "cupy":
             raise ValueError("NCCL Communicator only works with engine `cupy`")
 
         self.dtype = dtype
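For context, a usage sketch of the relaxed `base_comm` check: an NCCL communicator built by hand (following the same `get_unique_id`/`bcast` pattern as the old `subcomm_split` code) is passed as `base_comm` and is only accepted together with `engine="cupy"`. The one-GPU-per-process device selection and the top-level `DistributedArray` import are assumptions of this example, not part of the diff:

```python
# Hypothetical usage sketch; run with one MPI process per GPU, e.g. mpiexec -n 2
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
from pylops_mpi import DistributedArray

rank = MPI.COMM_WORLD.Get_rank()
size = MPI.COMM_WORLD.Get_size()
cp.cuda.Device(rank).use()  # assumed one-GPU-per-process mapping

# Build an NCCL communicator spanning all MPI ranks
nccl_id = nccl.get_unique_id() if rank == 0 else None
nccl_id = MPI.COMM_WORLD.bcast(nccl_id, root=0)
nccl_comm = nccl.NcclCommunicator(size, nccl_id, rank)

# Accepted because engine is "cupy"; any other engine raises ValueError
x = DistributedArray(global_shape=128, base_comm=nccl_comm, engine="cupy")
```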
@@ -199,16 +190,7 @@ def __setitem__(self, index, value):
             if isinstance(self.base_comm, MPI.Comm):
                 self.local_array[index] = self.base_comm.bcast(value)
             else:
-                # NCCL
-                if self.rank == 0:
-                    self.local_array[index] = value
-                self.base_comm.bcast(
-                    self.local_array[index].data.ptr,
-                    self.local_array[index].size,
-                    cupy_to_nccl_dtype[str(self.local_array[index].dtype)],
-                    0,
-                    cp.cuda.Stream.null.ptr,
-                )
+                nccl_bcast(self.base_comm, self.local_array, index, value)
         else:
             self.local_array[index] = value
 
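`nccl_bcast` replaces the inline broadcast. A sketch of what it might do, assembled from the deleted lines (the signature is taken from the new call site; the real helper in `pylops_mpi/utils/backend.py` may differ, for instance in how it determines the root rank):

```python
# Sketch only: nccl_bcast reconstructed from the inline code removed above.
import cupy as cp
from mpi4py import MPI
from pylops_mpi.utils.backend import cupy_to_nccl_dtype  # mapping used by the old inline code


def nccl_bcast(nccl_comm, local_array, index, value):
    # Root writes the new value, then every rank joins the collective
    # broadcast of that slice (in-place on the non-root ranks).
    if MPI.COMM_WORLD.Get_rank() == 0:  # assumed way of identifying the root
        local_array[index] = value
    nccl_comm.bcast(
        local_array[index].data.ptr,                        # device pointer
        local_array[index].size,                            # element count
        cupy_to_nccl_dtype[str(local_array[index].dtype)],  # NCCL dtype
        0,                                                  # root rank
        cp.cuda.Stream.null.ptr,                            # default stream
    )
```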
@@ -375,34 +357,7 @@ def asarray(self):
             final_array = self._allgather(self.local_array)
             return np.concatenate(final_array, axis=self.axis)
         else:
-            sizes_each_dim = list(zip(*self.local_shapes))
-            # NCCL allGather requires the send_buf to have the same
-            # size for every device
-            send_shape = tuple(map(max, sizes_each_dim))
-            pad_size = [
-                (0, send_shape[i] - self.local_array.shape[i])
-                for i in range(len(send_shape))
-            ]
-
-            send_buf = cp.pad(
-                self.local_array, pad_size, mode="constant", constant_values=0
-            )
-
-            # NCCL recommends to use one MPI Process per GPU
-            ndev = MPI.COMM_WORLD.Get_size()
-            recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
-            self._allgather(send_buf, recv_buf)
-
-            chunk_size = cp.prod(cp.asarray(send_shape))
-            chunks = [
-                recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
-            ]
-
-            for i in range(ndev):
-                slicing = tuple(slice(0, end) for end in self.local_shapes[i])
-                chunks[i] = chunks[i].reshape(send_shape)[slicing]
-            final_array = cp.concatenate([chunks[i] for i in range(ndev)], axis=self.axis)
-            return final_array
+            return nccl_asarray(self.base_comm, self.local_array, self.local_shapes, self.axis)
 
     @classmethod
     def to_dist(cls, x: NDArray,
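Likewise, the pad/allgather/unpad logic of `asarray` now lives in `nccl_asarray`. A sketch reconstructed from the deleted block, with the signature taken from the new call site and the gather delegated to the new `nccl_allgather` helper (assumed importable from `pylops_mpi.utils.backend`):

```python
# Sketch only: nccl_asarray reconstructed from the inline code removed above.
import cupy as cp
from mpi4py import MPI
from pylops_mpi.utils.backend import nccl_allgather


def nccl_asarray(nccl_comm, local_array, local_shapes, axis):
    # NCCL allGather requires every device to send a buffer of the same size,
    # so pad each local array up to the elementwise maximum shape.
    sizes_each_dim = list(zip(*local_shapes))
    send_shape = tuple(map(max, sizes_each_dim))
    pad_size = [
        (0, send_shape[i] - local_array.shape[i]) for i in range(len(send_shape))
    ]
    send_buf = cp.pad(local_array, pad_size, mode="constant", constant_values=0)

    # One MPI process per GPU is assumed, as recommended by NCCL
    ndev = MPI.COMM_WORLD.Get_size()
    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
    nccl_allgather(nccl_comm, send_buf, recv_buf)

    # Undo the padding: reshape each chunk and slice it back to its true shape
    chunk_size = int(cp.prod(cp.asarray(send_shape)))
    chunks = []
    for i in range(ndev):
        chunk = recv_buf[i * chunk_size:(i + 1) * chunk_size].reshape(send_shape)
        slicing = tuple(slice(0, end) for end in local_shapes[i])
        chunks.append(chunk[slicing])
    return cp.concatenate(chunks, axis=axis)
```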
@@ -497,22 +452,7 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
             self.base_comm.Allreduce(send_buf, recv_buf, op)
             return recv_buf
         else:
-            # NCCL
-            send_buf = (
-                send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
-            )
-            if recv_buf is None:
-                recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)
-
-            self.base_comm.allReduce(
-                send_buf.data.ptr,
-                recv_buf.data.ptr,
-                send_buf.size,
-                cupy_to_nccl_dtype[str(send_buf.dtype)],
-                mpi_op_to_nccl(op),
-                cp.cuda.Stream.null.ptr,
-            )
-            return recv_buf
+            return nccl_allreduce(self.base_comm, send_buf, recv_buf, op)
 
     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
 
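`nccl_allreduce` absorbs the inline reduction, and `_allreduce_subcomm` below reuses the same helper with `self.sub_comm`. A sketch based on the deleted lines (`cupy_to_nccl_dtype` and `mpi_op_to_nccl` are the helpers the old code imported from `pylops_mpi.utils.backend`):

```python
# Sketch only: nccl_allreduce reconstructed from the inline code removed above.
import cupy as cp
from mpi4py import MPI
from pylops_mpi.utils.backend import cupy_to_nccl_dtype, mpi_op_to_nccl


def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
    # Wrap primitive types (e.g. Python scalars) into a CuPy array
    send_buf = send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    if recv_buf is None:
        recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)

    nccl_comm.allReduce(
        send_buf.data.ptr,                        # device pointer, send buffer
        recv_buf.data.ptr,                        # device pointer, receive buffer
        send_buf.size,                            # element count
        cupy_to_nccl_dtype[str(send_buf.dtype)],  # NCCL dtype
        mpi_op_to_nccl(op),                       # translate MPI.SUM/MIN/MAX to NCCL
        cp.cuda.Stream.null.ptr,                  # default stream
    )
    return recv_buf
```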
@@ -527,21 +467,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
             return recv_buf
 
         else:
-            send_buf = (
-                send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
-            )
-            if recv_buf is None:
-                recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)
-
-            self.sub_comm.allReduce(
-                send_buf.data.ptr,
-                recv_buf.data.ptr,
-                send_buf.size,
-                cupy_to_nccl_dtype[str(send_buf.dtype)],
-                mpi_op_to_nccl(op),
-                cp.cuda.Stream.null.ptr,
-            )
-            return recv_buf
+            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
 
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
@@ -552,24 +478,7 @@ def _allgather(self, send_buf, recv_buf=None):
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
         else:
-            # NCCL
-            # Wrap primitive type to cupy array
-            send_buf = (
-                send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
-            )
-            if recv_buf is None:
-                recv_buf = cp.zeros(
-                    MPI.COMM_WORLD.Get_size() * send_buf.size,
-                    dtype=send_buf.dtype,
-                )
-            self.base_comm.allGather(
-                send_buf.data.ptr,
-                recv_buf.data.ptr,
-                send_buf.size,
-                cupy_to_nccl_dtype[str(send_buf.dtype)],
-                cp.cuda.Stream.null.ptr,
-            )
-            return recv_buf
+            return nccl_allgather(self.base_comm, send_buf, recv_buf)
 
     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
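Finally, a sketch of `nccl_allgather`, mirroring the deleted `_allgather` branch (again, the shipped helper in `pylops_mpi/utils/backend.py` may differ in details):

```python
# Sketch only: nccl_allgather reconstructed from the inline code removed above.
import cupy as cp
from mpi4py import MPI
from pylops_mpi.utils.backend import cupy_to_nccl_dtype  # mapping used by the old inline code


def nccl_allgather(nccl_comm, send_buf, recv_buf=None):
    # Wrap primitive types into a CuPy array so a device pointer can be taken
    send_buf = send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    if recv_buf is None:
        # One contribution per MPI process (one process per GPU is assumed)
        recv_buf = cp.zeros(MPI.COMM_WORLD.Get_size() * send_buf.size,
                            dtype=send_buf.dtype)
    nccl_comm.allGather(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf
```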