 from enum import Enum
 from numbers import Integral
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, NewType

 import numpy as np
 from mpi4py import MPI
-from pylops.utils import DTypeLike, NDArray, deps
+from pylops.utils import DTypeLike, NDArray
+from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_array_module, get_module, get_module_name
-from pylops_mpi.utils import deps as pylops_mpi_deps
+from pylops_mpi.utils import deps

-cupy_message = deps.cupy_import("the DistributedArray module")
-nccl_message = pylops_mpi_deps.nccl_import("the DistributedArray module")
+cupy_message = pylops_deps.cupy_import("the DistributedArray module")
+nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
     from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any

+NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)
+

 class Partition(Enum):
     r"""Enum class
@@ -57,32 +60,28 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
     """
     if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
         local_shape = global_shape
-    # Split the array
     else:
-        # TODO: Both NCCL and MPI have to use MPI.COMM_WORLD for shape splitting
-        # So it was decided to make it explicit. This will make base_comm unused.
+        # Split the array
         local_shape = list(global_shape)
-        if MPI.COMM_WORLD.Get_rank() < (global_shape[axis] % MPI.COMM_WORLD.Get_size()):
-            local_shape[axis] = global_shape[axis] // MPI.COMM_WORLD.Get_size() + 1
+        if base_comm.Get_rank() < (global_shape[axis] % base_comm.Get_size()):
+            local_shape[axis] = global_shape[axis] // base_comm.Get_size() + 1
         else:
-            local_shape[axis] = global_shape[axis] // MPI.COMM_WORLD.Get_size()
+            local_shape[axis] = global_shape[axis] // base_comm.Get_size()
     return tuple(local_shape)


-def subcomm_split(mask, base_comm: MPI.Comm = MPI.COMM_WORLD):
+def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = MPI.COMM_WORLD):
     """Create new communicators based on mask

-    This method creates new NCCL communicators based on ``mask``.
-    Contrary to MPI, NCCL does not provide support for splitting of a communicator
-    in multiple subcommunicators; this is therefore handled explicitly by this method.
+    This method creates new communicators based on ``mask``.

     Parameters
     ----------
     mask : :obj:`list`
         Mask defining subsets of ranks to consider when performing 'global'
         operations on the distributed array such as dot product or norm.

-    base_comm : :obj:`mpi4py.MPI.Comm`, optional
+    comm : :obj:`mpi4py.MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator`, optional
         A Communicator over which array is distributed
         Defaults to ``mpi4py.MPI.COMM_WORLD``.

@@ -91,12 +90,12 @@ def subcomm_split(mask, base_comm: MPI.Comm = MPI.COMM_WORLD):
     sub_comm : :obj:`mpi4py.MPI.Comm` or :obj:`cupy.cuda.nccl.NcclCommunicator`
         Subcommunicator according to mask
     """
-    if isinstance(base_comm, MPI.Comm):
-        comm = MPI.COMM_WORLD
+    # NcclCommunicatorType cannot be used with isinstance() so check the negate of MPI.Comm
+    if deps.nccl_enabled and not isinstance(comm, MPI.Comm):
+        sub_comm = nccl_split(mask)
+    else:
         rank = comm.Get_rank()
         sub_comm = comm.Split(color=mask[rank], key=rank)
-    else:
-        sub_comm = nccl_split(mask)
     return sub_comm


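For context on how the refactored `local_split` and `subcomm_split` are meant to be called, here is a minimal usage sketch of the MPI path; the four-rank mask, the `(10, 5)` shape, and the `mpiexec` invocation are illustrative assumptions, not part of this PR:

```python
# Sketch only: run with e.g. `mpiexec -n 4 python example.py` (assumed launch command).
from mpi4py import MPI
from pylops_mpi.DistributedArray import Partition, local_split, subcomm_split

comm = MPI.COMM_WORLD

# Group even and odd ranks; 'global' operations then stay inside each group.
mask = [rank % 2 for rank in range(comm.Get_size())]
sub_comm = subcomm_split(mask, comm)

# Shape of this rank's chunk of a (10, 5) array scattered along axis 0.
local_shape = local_split((10, 5), comm, Partition.SCATTER, axis=0)
print(comm.Get_rank(), sub_comm.Get_size(), local_shape)
```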
@@ -124,6 +123,8 @@ class DistributedArray:
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
         MPI Communicator over which array is distributed.
         Defaults to ``mpi4py.MPI.COMM_WORLD``.
+    base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
+        NCCL Communicator over which array is distributed.
     partition : :obj:`Partition`, optional
         Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``.
     axis : :obj:`int`, optional
@@ -140,7 +141,8 @@ class DistributedArray:
     """

     def __init__(self, global_shape: Union[Tuple, Integral],
-                 base_comm: Optional[Union[MPI.Comm, NcclCommunicator]] = MPI.COMM_WORLD,
+                 base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
+                 base_comm_nccl: Optional[NcclCommunicatorType] = None,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
                  local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  mask: Optional[List[Integral]] = None,
@@ -154,16 +156,17 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         if partition not in Partition:
             raise ValueError(f"Should be either {Partition.BROADCAST}, "
                              f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}")
-        if not isinstance(base_comm, MPI.Comm) and engine != "cupy":
+        if base_comm_nccl and engine != "cupy":
             raise ValueError("NCCL Communicator only works with engine `cupy`")

         self.dtype = dtype
         self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
+        self._base_comm_nccl = base_comm_nccl
         self._partition = partition
         self._axis = axis
         self._mask = mask
-        self._sub_comm = base_comm if mask is None else subcomm_split(mask, base_comm)
+        self._sub_comm = (base_comm if base_comm_nccl is None else base_comm_nccl) if mask is None else subcomm_split(mask, (base_comm if base_comm_nccl is None else base_comm_nccl))
         local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[self.rank] if local_shapes else local_split(global_shape, base_comm,
@@ -195,12 +198,11 @@ def __setitem__(self, index, value):
             Represents the value that will be assigned to the local array at
             the specified index positions.
         """
-
         if self.partition is Partition.BROADCAST:
-            if isinstance(self.base_comm, MPI.Comm):
-                self.local_array[index] = self.base_comm.bcast(value)
+            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+                nccl_bcast(self.base_comm_nccl, self.local_array, index, value)
             else:
-                nccl_bcast(self.base_comm, self.local_array, index, value)
+                self.local_array[index] = self.base_comm.bcast(value)
         else:
             self.local_array[index] = value

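The `Partition.BROADCAST` branch above is easiest to see with a tiny MPI-only sketch; the shape and values below are made up for illustration:

```python
# Run under MPI, e.g. `mpiexec -n 2 python example.py` (illustrative launch command).
import numpy as np
from pylops_mpi import DistributedArray, Partition

# Every rank holds the full 5-element array.
arr = DistributedArray(global_shape=5, partition=Partition.BROADCAST, dtype=np.float64)
# __setitem__ goes through base_comm.bcast, so all ranks end up with rank 0's values.
arr[:] = np.arange(5, dtype=np.float64)
print(arr.local_array)   # [0. 1. 2. 3. 4.] on every rank
```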
@@ -224,6 +226,16 @@ def base_comm(self):
         """
         return self._base_comm

+    @property
+    def base_comm_nccl(self):
+        """Base NCCL Communicator
+
+        Returns
+        -------
+        base_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        """
+        return self._base_comm_nccl
+
     @property
     def local_shape(self):
         """Local Shape of the Distributed array
@@ -327,23 +339,23 @@ def local_shapes(self):
         -------
         local_shapes : :obj:`list`
         """
-        if self.base_comm is MPI.COMM_WORLD:
-            return self._allgather(self.local_shape)
-        else:
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             # gather tuple of shapes from every rank and copy from GPU to CPU
             all_tuples = self._allgather(self.local_shape).get()
             # NCCL returns the flat array that packs every tuple as 1-dimensional array
             # unpack each tuple from each rank
             tuple_len = len(self.local_shape)
             return [tuple(all_tuples[i: i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+        else:
+            return self._allgather(self.local_shape)

     @property
     def sub_comm(self):
         """MPI Sub-Communicator

         Returns
         -------
-        sub_comm : :obj:`MPI.Comm`
+        sub_comm : :obj:`MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator`
         """
         return self._sub_comm

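The list comprehension in `local_shapes` above unpacks the flat buffer that NCCL's allgather returns; a standalone sketch with made-up numbers shows the reshaping logic in isolation:

```python
# Pure-Python illustration (no GPU needed); values are made up.
# NCCL allgather packs every rank's shape tuple into one flat 1-D buffer ...
all_tuples = [4, 5, 3, 5, 3, 5]   # 3 ranks, each with a 2-D local shape
tuple_len = 2                     # len(self.local_shape) is the same on every rank
# ... so fixed-size slices recover one shape tuple per rank.
local_shapes = [tuple(all_tuples[i: i + tuple_len])
                for i in range(0, len(all_tuples), tuple_len)]
print(local_shapes)               # [(4, 5), (3, 5), (3, 5)]
```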
@@ -362,16 +374,17 @@ def asarray(self):
             # Get only self.local_array.
             return self.local_array

-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_asarray(self.base_comm_nccl, self.local_array, self.local_shapes, self.axis)
+        else:
             # Gather all the local arrays and apply concatenation.
             final_array = self._allgather(self.local_array)
             return np.concatenate(final_array, axis=self.axis)
-        else:
-            return nccl_asarray(self.base_comm, self.local_array, self.local_shapes, self.axis)

     @classmethod
     def to_dist(cls, x: NDArray,
                 base_comm: MPI.Comm = MPI.COMM_WORLD,
+                base_comm_nccl: NcclCommunicatorType = None,
                 partition: Partition = Partition.SCATTER,
                 axis: int = 0,
                 local_shapes: Optional[List[Tuple]] = None,
@@ -383,7 +396,9 @@ def to_dist(cls, x: NDArray,
         x : :obj:`numpy.ndarray`
             Global array.
         base_comm : :obj:`MPI.Comm`, optional
-            Type of elements in input array. Defaults to ``MPI.COMM_WORLD``
+            MPI base communicator
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
+            NCCL base communicator
         partition : :obj:`Partition`, optional
             Distributes the array, Defaults to ``Partition.Scatter``.
         axis : :obj:`int`, optional
@@ -401,6 +416,7 @@ def to_dist(cls, x: NDArray,
         """
         dist_array = DistributedArray(global_shape=x.shape,
                                       base_comm=base_comm,
+                                      base_comm_nccl=base_comm_nccl,
                                       partition=partition,
                                       axis=axis,
                                       local_shapes=local_shapes,
@@ -455,39 +471,37 @@ def _check_mask(self, dist_array):
     def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """Allreduce operation
         """
-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
+        else:
             if recv_buf is None:
                 return self.base_comm.allreduce(send_buf, op)
             # For MIN and MAX which require recv_buf
             self.base_comm.Allreduce(send_buf, recv_buf, op)
             return recv_buf
-        else:
-            return nccl_allreduce(self.base_comm, send_buf, recv_buf, op)

     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """Allreduce operation with subcommunicator
         """
-
-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
+        else:
             if recv_buf is None:
                 return self.sub_comm.allreduce(send_buf, op)
             # For MIN and MAX which require recv_buf
             self.sub_comm.Allreduce(send_buf, recv_buf, op)
             return recv_buf

-        else:
-            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
-
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+        else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
-        else:
-            return nccl_allgather(self.base_comm, send_buf, recv_buf)

     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
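Putting the pieces together, here is a hedged end-to-end sketch of the new NCCL path. It assumes one GPU per MPI rank and a CuPy build with NCCL support, and bootstraps the NCCL communicator inline for illustration; none of this setup code is part of the PR itself.

```python
import cupy as cp
from cupy.cuda import nccl
from mpi4py import MPI
from pylops_mpi import DistributedArray, Partition

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
cp.cuda.Device(rank).use()                       # assumption: rank i owns GPU i

# Bootstrap an NCCL communicator over MPI (shown inline only for illustration).
nccl_id = nccl.get_unique_id() if rank == 0 else None
nccl_id = comm.bcast(nccl_id, root=0)
nccl_comm = nccl.NcclCommunicator(size, nccl_id, rank)

# With base_comm_nccl set (and deps.nccl_enabled True), the collectives above
# (_allreduce, _allgather, asarray, __setitem__) dispatch to the nccl_* helpers.
x = DistributedArray(global_shape=100, base_comm_nccl=nccl_comm,
                     partition=Partition.SCATTER, engine="cupy", dtype=cp.float32)
x[:] = cp.ones(x.local_shape, dtype=cp.float32)  # fill this rank's local chunk
y = x.asarray()                                  # gathered via nccl_asarray
```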