
Commit 3dc41fe

Change DistributedArray() to take base_comm and base_comm_nccl as suggested by PR
1 parent cdc5950 commit 3dc41fe
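The change splits the single communicator argument into two explicit keywords: base_comm keeps the mpi4py communicator, while base_comm_nccl optionally carries a cupy.cuda.nccl.NcclCommunicator and is only accepted together with engine="cupy". A minimal usage sketch under those assumptions (not taken from this commit; the commented-out nccl_comm is a hypothetical, already-initialized NCCL communicator):

# Sketch only: illustrates the new keyword split introduced by this commit.
from mpi4py import MPI
from pylops_mpi import DistributedArray, Partition

# CPU/MPI path: base_comm_nccl is left at its default of None
x = DistributedArray(global_shape=(100,),
                     base_comm=MPI.COMM_WORLD,
                     partition=Partition.SCATTER,
                     axis=0,
                     engine="numpy")

# GPU/NCCL path: pass the NCCL handle separately and use the cupy engine
# (nccl_comm is hypothetical here and must be created beforehand)
# y = DistributedArray(global_shape=(100,),
#                      base_comm_nccl=nccl_comm,
#                      partition=Partition.SCATTER,
#                      axis=0,
#                      engine="cupy")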

3 files changed: 77 additions & 63 deletions

pylops_mpi/DistributedArray.py

Lines changed: 60 additions & 46 deletions
@@ -1,23 +1,26 @@
 from enum import Enum
 from numbers import Integral
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, NewType

 import numpy as np
 from mpi4py import MPI
-from pylops.utils import DTypeLike, NDArray, deps
+from pylops.utils import DTypeLike, NDArray
+from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_array_module, get_module, get_module_name
-from pylops_mpi.utils import deps as pylops_mpi_deps
+from pylops_mpi.utils import deps

-cupy_message = deps.cupy_import("the DistributedArray module")
-nccl_message = pylops_mpi_deps.nccl_import("the DistributedArray module")
+cupy_message = pylops_deps.cupy_import("the DistributedArray module")
+nccl_message = deps.nccl_import("the DistributedArray module")

 if nccl_message is None and cupy_message is None:
     from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any

+NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)
+

 class Partition(Enum):
     r"""Enum class
@@ -57,32 +60,28 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
     """
     if partition in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
         local_shape = global_shape
-    # Split the array
     else:
-        # TODO: Both NCCL and MPI have to use MPI.COMM_WORLD for shape splitting
-        # So it was decided to make it explicit. This will make base_comm unused.
+        # Split the array
         local_shape = list(global_shape)
-        if MPI.COMM_WORLD.Get_rank() < (global_shape[axis] % MPI.COMM_WORLD.Get_size()):
-            local_shape[axis] = global_shape[axis] // MPI.COMM_WORLD.Get_size() + 1
+        if base_comm.Get_rank() < (global_shape[axis] % base_comm.Get_size()):
+            local_shape[axis] = global_shape[axis] // base_comm.Get_size() + 1
         else:
-            local_shape[axis] = global_shape[axis] // MPI.COMM_WORLD.Get_size()
+            local_shape[axis] = global_shape[axis] // base_comm.Get_size()
     return tuple(local_shape)


-def subcomm_split(mask, base_comm: MPI.Comm = MPI.COMM_WORLD):
+def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = MPI.COMM_WORLD):
     """Create new communicators based on mask

-    This method creates new NCCL communicators based on ``mask``.
-    Contrary to MPI, NCCL does not provide support for splitting of a communicator
-    in multiple subcommunicators; this is therefore handled explicitly by this method.
+    This method creates new communicators based on ``mask``.

     Parameters
     ----------
     mask : :obj:`list`
         Mask defining subsets of ranks to consider when performing 'global'
         operations on the distributed array such as dot product or norm.

-    base_comm : :obj:`mpi4py.MPI.Comm`, optional
+    comm : :obj:`mpi4py.MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator`, optional
         A Communicator over which array is distributed
         Defaults to ``mpi4py.MPI.COMM_WORLD``.

@@ -91,12 +90,12 @@ def subcomm_split(mask, base_comm: MPI.Comm = MPI.COMM_WORLD):
     sub_comm : :obj:`mpi4py.MPI.Comm` or :obj:`cupy.cuda.nccl.NcclCommunicator`
         Subcommunicator according to mask
     """
-    if isinstance(base_comm, MPI.Comm):
-        comm = MPI.COMM_WORLD
+    # NcclCommunicatorType cannot be used with isinstance() so check the negate of MPI.Comm
+    if deps.nccl_enabled and not isinstance(comm, MPI.Comm):
+        sub_comm = nccl_split(mask)
+    else:
         rank = comm.Get_rank()
         sub_comm = comm.Split(color=mask[rank], key=rank)
-    else:
-        sub_comm = nccl_split(mask)
     return sub_comm


@@ -124,6 +123,8 @@ class DistributedArray:
     base_comm : :obj:`mpi4py.MPI.Comm`, optional
         MPI Communicator over which array is distributed.
         Defaults to ``mpi4py.MPI.COMM_WORLD``.
+    base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
+        NCCL Communicator over which array is distributed.
     partition : :obj:`Partition`, optional
         Broadcast, UnsafeBroadcast, or Scatter the array. Defaults to ``Partition.SCATTER``.
     axis : :obj:`int`, optional
@@ -140,7 +141,8 @@ class DistributedArray:
     """

     def __init__(self, global_shape: Union[Tuple, Integral],
-                 base_comm: Optional[Union[MPI.Comm, NcclCommunicator]] = MPI.COMM_WORLD,
+                 base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
+                 base_comm_nccl: Optional[NcclCommunicatorType] = None,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
                  local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  mask: Optional[List[Integral]] = None,
@@ -154,16 +156,17 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         if partition not in Partition:
             raise ValueError(f"Should be either {Partition.BROADCAST}, "
                              f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}")
-        if not isinstance(base_comm, MPI.Comm) and engine != "cupy":
+        if base_comm_nccl and engine != "cupy":
             raise ValueError("NCCL Communicator only works with engine `cupy`")

         self.dtype = dtype
         self._global_shape = _value_or_sized_to_tuple(global_shape)
         self._base_comm = base_comm
+        self._base_comm_nccl = base_comm_nccl
         self._partition = partition
         self._axis = axis
         self._mask = mask
-        self._sub_comm = base_comm if mask is None else subcomm_split(mask, base_comm)
+        self._sub_comm = (base_comm if base_comm_nccl is None else base_comm_nccl) if mask is None else subcomm_split(mask, (base_comm if base_comm_nccl is None else base_comm_nccl))
         local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes]
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[self.rank] if local_shapes else local_split(global_shape, base_comm,
@@ -195,12 +198,11 @@ def __setitem__(self, index, value):
             Represents the value that will be assigned to the local array at
             the specified index positions.
         """
-
         if self.partition is Partition.BROADCAST:
-            if isinstance(self.base_comm, MPI.Comm):
-                self.local_array[index] = self.base_comm.bcast(value)
+            if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+                nccl_bcast(self.base_comm_nccl, self.local_array, index, value)
             else:
-                nccl_bcast(self.base_comm, self.local_array, index, value)
+                self.local_array[index] = self.base_comm.bcast(value)
         else:
             self.local_array[index] = value

@@ -224,6 +226,16 @@ def base_comm(self):
         """
         return self._base_comm

+    @property
+    def base_comm_nccl(self):
+        """Base NCCL Communicator
+
+        Returns
+        -------
+        base_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        """
+        return self._base_comm_nccl
+
     @property
     def local_shape(self):
         """Local Shape of the Distributed array
@@ -327,23 +339,23 @@ def local_shapes(self):
         -------
         local_shapes : :obj:`list`
         """
-        if self.base_comm is MPI.COMM_WORLD:
-            return self._allgather(self.local_shape)
-        else:
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             # gather tuple of shapes from every rank and copy from GPU to CPU
             all_tuples = self._allgather(self.local_shape).get()
             # NCCL returns the flat array that packs every tuple as 1-dimensional array
             # unpack each tuple from each rank
             tuple_len = len(self.local_shape)
             return [tuple(all_tuples[i : i + tuple_len]) for i in range(0, len(all_tuples), tuple_len)]
+        else:
+            return self._allgather(self.local_shape)

     @property
     def sub_comm(self):
         """MPI Sub-Communicator

         Returns
         -------
-        sub_comm : :obj:`MPI.Comm`
+        sub_comm : :obj:`MPI.Comm` or `cupy.cuda.nccl.NcclCommunicator`
         """
         return self._sub_comm

@@ -362,16 +374,17 @@ def asarray(self):
             # Get only self.local_array.
             return self.local_array

-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_asarray(self.base_comm_nccl, self.local_array, self.local_shapes, self.axis)
+        else:
             # Gather all the local arrays and apply concatenation.
             final_array = self._allgather(self.local_array)
             return np.concatenate(final_array, axis=self.axis)
-        else:
-            return nccl_asarray(self.base_comm, self.local_array, self.local_shapes, self.axis)

     @classmethod
     def to_dist(cls, x: NDArray,
                 base_comm: MPI.Comm = MPI.COMM_WORLD,
+                base_comm_nccl: NcclCommunicatorType = None,
                 partition: Partition = Partition.SCATTER,
                 axis: int = 0,
                 local_shapes: Optional[List[Tuple]] = None,
@@ -383,7 +396,9 @@ def to_dist(cls, x: NDArray,
         x : :obj:`numpy.ndarray`
             Global array.
         base_comm : :obj:`MPI.Comm`, optional
-            Type of elements in input array. Defaults to ``MPI.COMM_WORLD``
+            MPI base communicator
+        base_comm_nccl : :obj:`cupy.cuda.nccl.NcclCommunicator`, optional
+            NCCL base communicator
         partition : :obj:`Partition`, optional
             Distributes the array, Defaults to ``Partition.Scatter``.
         axis : :obj:`int`, optional
@@ -401,6 +416,7 @@ def to_dist(cls, x: NDArray,
         """
         dist_array = DistributedArray(global_shape=x.shape,
                                       base_comm=base_comm,
+                                      base_comm_nccl=base_comm_nccl,
                                       partition=partition,
                                       axis=axis,
                                       local_shapes=local_shapes,
@@ -455,39 +471,37 @@ def _check_mask(self, dist_array):
     def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """Allreduce operation
         """
-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
+        else:
             if recv_buf is None:
                 return self.base_comm.allreduce(send_buf, op)
             # For MIN and MAX which require recv_buf
             self.base_comm.Allreduce(send_buf, recv_buf, op)
             return recv_buf
-        else:
-            return nccl_allreduce(self.base_comm, send_buf, recv_buf, op)

     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         """Allreduce operation with subcommunicator
         """
-
-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
+        else:
             if recv_buf is None:
                 return self.sub_comm.allreduce(send_buf, op)
             # For MIN and MAX which require recv_buf
             self.sub_comm.Allreduce(send_buf, recv_buf, op)
             return recv_buf

-        else:
-            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
-
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
-        if isinstance(self.base_comm, MPI.Comm):
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
+        else:
             if recv_buf is None:
                 return self.base_comm.allgather(send_buf)
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
-        else:
-            return nccl_allgather(self.base_comm, send_buf, recv_buf)

     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
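Every collective helper touched above (__setitem__, local_shapes, asarray, _allreduce, _allreduce_subcomm, _allgather) now follows the same dispatch pattern: if NCCL support is importable and a base_comm_nccl handle was stored, route the call to the pylops_mpi.utils._nccl wrappers, otherwise fall back to mpi4py on base_comm. A simplified, standalone sketch of that pattern (illustrative only, not part of the commit; arr stands for any object exposing base_comm/base_comm_nccl, the deps.nccl_enabled check is folded into the attribute test, and the nccl_allreduce signature mirrors the calls in the diff):

# Illustrative dispatch helper, not part of the commit.
from mpi4py import MPI

def allreduce_dispatch(arr, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
    """Reduce send_buf across ranks via NCCL when available, else via MPI."""
    if getattr(arr, "base_comm_nccl", None) is not None:
        # GPU path: CuPy buffers reduced through the NCCL communicator
        from pylops_mpi.utils._nccl import nccl_allreduce
        return nccl_allreduce(arr.base_comm_nccl, send_buf, recv_buf, op)
    # CPU path: lowercase allreduce handles generic Python objects,
    # uppercase Allreduce fills a preallocated buffer (used for MIN/MAX)
    if recv_buf is None:
        return arr.base_comm.allreduce(send_buf, op)
    arr.base_comm.Allreduce(send_buf, recv_buf, op)
    return recv_buf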

tests/test_distributedarray.py

Lines changed: 2 additions & 2 deletions
@@ -201,7 +201,7 @@ def test_distributed_maskeddot(par1, par2):
     """Test Distributed Dot product with masked array"""
     # number of subcommunicators
     if MPI.COMM_WORLD.Get_size() % 2 == 0:
-        nsub = 2
+        nsub = 2
     elif MPI.COMM_WORLD.Get_size() % 3 == 0:
         nsub = 3
     else:
@@ -236,7 +236,7 @@ def test_distributed_maskednorm(par):
     """Test Distributed numpy.linalg.norm method with masked array"""
     # number of subcommunicators
     if MPI.COMM_WORLD.Get_size() % 2 == 0:
-        nsub = 2
+        nsub = 2
     elif MPI.COMM_WORLD.Get_size() % 3 == 0:
         nsub = 3
     else:
