
Commit f1238cb

move nccl-related calls to backend.py to avoid direct import

1 parent bf15ea0 commit f1238cb

File tree

2 files changed: +207 -111 lines changed

pylops_mpi/DistributedArray.py

Lines changed: 16 additions & 107 deletions
@@ -2,14 +2,12 @@
 from numbers import Integral
 from typing import List, Optional, Tuple, Union

-import cupy as cp
-import cupy.cuda.nccl as nccl
 import numpy as np
 from mpi4py import MPI
 from pylops.utils import DTypeLike, NDArray
 from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_array_module, get_module, get_module_name
-from pylops_mpi.utils.backend import cupy_to_nccl_dtype, mpi_op_to_nccl
+from pylops_mpi.utils.backend import nccl_split, nccl_allgather, nccl_allreduce, nccl_bcast, nccl_asarray


 class Partition(Enum):
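With this hunk, DistributedArray.py no longer imports cupy or cupy.cuda.nccl at module level; every NCCL-specific symbol now comes from pylops_mpi/utils/backend.py. The second changed file is not reproduced on this page, so the following is only a minimal sketch of how backend.py might guard the GPU imports so the package still imports on CPU-only installs (the name NCCL_AVAILABLE is illustrative, not taken from the diff):

# hypothetical sketch of an import guard in pylops_mpi/utils/backend.py
try:
    import cupy as cp
    import cupy.cuda.nccl as nccl
    NCCL_AVAILABLE = True
except ImportError:
    cp = None
    nccl = None
    NCCL_AVAILABLE = False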
@@ -62,7 +60,7 @@ def local_split(global_shape: Tuple, base_comm: MPI.Comm,
     return tuple(local_shape)


-def subcomm_split(mask, base_comm: MPI.Comm):
+def subcomm_split(mask, base_comm: MPI.Comm = MPI.COMM_WORLD):
     """Create new communicators based on mask
     This method creates new NCCL communicators based on ``mask``.
     Contrary to MPI, NCCL does not provide support for splitting of a communicator in multiple subcommunicators;
@@ -82,20 +80,12 @@ def subcomm_split(mask, base_comm: MPI.Comm):
     -------
     Union[mpi4py.MPI.Comm, cupy.cuda.nccl.NcclCommunicator]]: a subcommunicator according to mask
     """
-    comm = MPI.COMM_WORLD
-    rank = comm.Get_rank()
-    sub_comm = comm.Split(color=mask[rank], key=rank)
-    if isinstance(base_comm, nccl.NcclCommunicator):
-        sub_rank = sub_comm.Get_rank()
-        sub_size = sub_comm.Get_size()
-
-        if sub_rank == 0:
-            nccl_id_bytes = nccl.get_unique_id()
-        else:
-            nccl_id_bytes = None
-        nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
-        sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)
-
+    if isinstance(base_comm, MPI.Comm):
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        sub_comm = comm.Split(color=mask[rank], key=rank)
+    else:
+        sub_comm = nccl_split(mask)
     return sub_comm

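The MPI split stays inline, while the NCCL path is delegated to nccl_split. Its body lives in the other changed file and is not shown on this page; based on the code removed above, a plausible reconstruction is:

# hypothetical reconstruction of nccl_split in pylops_mpi/utils/backend.py
import cupy.cuda.nccl as nccl
from mpi4py import MPI


def nccl_split(mask):
    # Split MPI.COMM_WORLD by color, then build an NCCL communicator
    # that mirrors the resulting MPI sub-communicator.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    sub_comm = comm.Split(color=mask[rank], key=rank)
    sub_rank = sub_comm.Get_rank()
    sub_size = sub_comm.Get_size()
    # rank 0 of each group creates the NCCL unique id and shares it over MPI
    nccl_id_bytes = nccl.get_unique_id() if sub_rank == 0 else None
    nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
    return nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)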

@@ -138,8 +128,9 @@ class DistributedArray:
         Type of elements in input array. Defaults to ``numpy.float64``.
     """

+    # TODO: Type Annotation for base_comm without NCCL import
     def __init__(self, global_shape: Union[Tuple, Integral],
-                 base_comm: Optional[Union[MPI.Comm, nccl.NcclCommunicator]] = MPI.COMM_WORLD,
+                 base_comm=MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
                  local_shapes: Optional[List[Union[Tuple, Integral]]] = None,
                  mask: Optional[List[Integral]] = None,
@@ -153,7 +144,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         if partition not in Partition:
             raise ValueError(f"Should be either {Partition.BROADCAST}, "
                              f"{Partition.UNSAFE_BROADCAST} or {Partition.SCATTER}")
-        if isinstance(base_comm, nccl.NcclCommunicator) and engine != "cupy":
+        if not isinstance(base_comm, MPI.Comm) and engine != "cupy":
            raise ValueError("NCCL Communicator only works with engine `cupy`")

         self.dtype = dtype
@@ -199,16 +190,7 @@ def __setitem__(self, index, value):
             if isinstance(self.base_comm, MPI.Comm):
                 self.local_array[index] = self.base_comm.bcast(value)
             else:
-                # NCCL
-                if self.rank == 0:
-                    self.local_array[index] = value
-                self.base_comm.bcast(
-                    self.local_array[index].data.ptr,
-                    self.local_array[index].size,
-                    cupy_to_nccl_dtype[str(self.local_array[index].dtype)],
-                    0,
-                    cp.cuda.Stream.null.ptr,
-                )
+                nccl_bcast(self.base_comm, self.local_array, index, value)
         else:
             self.local_array[index] = value

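The inline NCCL broadcast collapses into a single nccl_bcast call. A sketch of what that helper presumably does, reconstructed from the removed lines; how it obtains the rank is an assumption (one MPI process per GPU, as the removed comments elsewhere in this file state), and cupy_to_nccl_dtype is the dtype mapping backend.py already exported before this commit:

# hypothetical reconstruction of nccl_bcast in pylops_mpi/utils/backend.py
import cupy as cp
from mpi4py import MPI


def nccl_bcast(nccl_comm, local_array, index, value):
    # rank 0 writes the new value, then every rank joins the collective
    # broadcast of that slice (mirrors the code removed above)
    if MPI.COMM_WORLD.Get_rank() == 0:
        local_array[index] = value
    nccl_comm.bcast(
        local_array[index].data.ptr,
        local_array[index].size,
        cupy_to_nccl_dtype[str(local_array[index].dtype)],
        0,  # root rank
        cp.cuda.Stream.null.ptr,
    )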

@@ -375,34 +357,7 @@ def asarray(self):
             final_array = self._allgather(self.local_array)
             return np.concatenate(final_array, axis=self.axis)
         else:
-            sizes_each_dim = list(zip(*self.local_shapes))
-            # NCCL allGather requires the send_buf to have the same
-            # size for every device
-            send_shape = tuple(map(max, sizes_each_dim))
-            pad_size = [
-                (0, send_shape[i] - self.local_array.shape[i])
-                for i in range(len(send_shape))
-            ]
-
-            send_buf = cp.pad(
-                self.local_array, pad_size, mode="constant", constant_values=0
-            )
-
-            # NCCL recommends to use one MPI Process per GPU
-            ndev = MPI.COMM_WORLD.Get_size()
-            recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
-            self._allgather(send_buf, recv_buf)
-
-            chunk_size = cp.prod(cp.asarray(send_shape))
-            chunks = [
-                recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
-            ]
-
-            for i in range(ndev):
-                slicing = tuple(slice(0, end) for end in self.local_shapes[i])
-                chunks[i] = chunks[i].reshape(send_shape)[slicing]
-            final_array = cp.concatenate([chunks[i] for i in range(ndev)], axis=self.axis)
-            return final_array
+            return nccl_asarray(self.base_comm, self.local_array, self.local_shapes, self.axis)

     @classmethod
     def to_dist(cls, x: NDArray,
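The pad/gather/trim logic moves wholesale into nccl_asarray: NCCL allGather expects equally sized send buffers on every rank, so each local array is padded to the elementwise-maximum shape, gathered, and then trimmed back. A sketch based on the removed code, with the signature taken from the call above (the actual body in backend.py is not shown here):

# hypothetical reconstruction of nccl_asarray in pylops_mpi/utils/backend.py
import cupy as cp
import numpy as np
from mpi4py import MPI


def nccl_asarray(nccl_comm, local_array, local_shapes, axis):
    # pad every local array to the elementwise-maximum shape so that all
    # ranks send buffers of identical size
    send_shape = tuple(map(max, zip(*local_shapes)))
    pad_size = [(0, send_shape[i] - local_array.shape[i])
                for i in range(len(send_shape))]
    send_buf = cp.pad(local_array, pad_size, mode="constant", constant_values=0)

    # one MPI process per GPU is assumed, as in the removed code
    ndev = MPI.COMM_WORLD.Get_size()
    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
    nccl_allgather(nccl_comm, send_buf, recv_buf)

    # split the flat receive buffer per rank, drop the padding, concatenate
    chunk_size = int(np.prod(send_shape))
    chunks = [recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)]
    for i in range(ndev):
        slicing = tuple(slice(0, end) for end in local_shapes[i])
        chunks[i] = chunks[i].reshape(send_shape)[slicing]
    return cp.concatenate(chunks, axis=axis)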
@@ -497,22 +452,7 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
             self.base_comm.Allreduce(send_buf, recv_buf, op)
             return recv_buf
         else:
-            # NCCL
-            send_buf = (
-                send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
-            )
-            if recv_buf is None:
-                recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)
-
-            self.base_comm.allReduce(
-                send_buf.data.ptr,
-                recv_buf.data.ptr,
-                send_buf.size,
-                cupy_to_nccl_dtype[str(send_buf.dtype)],
-                mpi_op_to_nccl(op),
-                cp.cuda.Stream.null.ptr,
-            )
-            return recv_buf
+            return nccl_allreduce(self.base_comm, send_buf, recv_buf, op)

     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):

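_allreduce and _allreduce_subcomm (next hunk) now share one helper; only the communicator passed in differs. A plausible reconstruction of nccl_allreduce from the removed lines, reusing the cupy_to_nccl_dtype and mpi_op_to_nccl mappings that backend.py already provided:

# hypothetical reconstruction of nccl_allreduce in pylops_mpi/utils/backend.py
import cupy as cp
from mpi4py import MPI


def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op=MPI.SUM):
    # wrap scalars / host arrays into CuPy so raw device pointers exist
    send_buf = send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    if recv_buf is None:
        recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)
    nccl_comm.allReduce(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        mpi_op_to_nccl(op),  # map MPI.SUM/MIN/MAX onto the NCCL op enum
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf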

@@ -527,21 +467,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
             return recv_buf

         else:
-            send_buf = (
-                send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
-            )
-            if recv_buf is None:
-                recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)
-
-            self.sub_comm.allReduce(
-                send_buf.data.ptr,
-                recv_buf.data.ptr,
-                send_buf.size,
-                cupy_to_nccl_dtype[str(send_buf.dtype)],
-                mpi_op_to_nccl(op),
-                cp.cuda.Stream.null.ptr,
-            )
-            return recv_buf
+            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)

     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
@@ -552,24 +478,7 @@ def _allgather(self, send_buf, recv_buf=None):
             self.base_comm.Allgather(send_buf, recv_buf)
             return recv_buf
         else:
-            # NCCL
-            # Wrap primitive type to cupy array
-            send_buf = (
-                send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
-            )
-            if recv_buf is None:
-                recv_buf = cp.zeros(
-                    MPI.COMM_WORLD.Get_size() * send_buf.size,
-                    dtype=send_buf.dtype,
-                )
-            self.base_comm.allGather(
-                send_buf.data.ptr,
-                recv_buf.data.ptr,
-                send_buf.size,
-                cupy_to_nccl_dtype[str(send_buf.dtype)],
-                cp.cuda.Stream.null.ptr,
-            )
-            return recv_buf
+            return nccl_allgather(self.base_comm, send_buf, recv_buf)

     def __neg__(self):
         arr = DistributedArray(global_shape=self.global_shape,
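Finally, _allgather delegates to nccl_allgather, which (judging from the removed lines) wraps the input into a CuPy array, allocates a receive buffer sized for one chunk per rank, and issues the NCCL collective. A sketch under the same assumptions as above:

# hypothetical reconstruction of nccl_allgather in pylops_mpi/utils/backend.py
import cupy as cp
from mpi4py import MPI


def nccl_allgather(nccl_comm, send_buf, recv_buf=None):
    # wrap primitive types / host arrays into CuPy arrays
    send_buf = send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    if recv_buf is None:
        # one send-sized chunk per rank (one MPI process per GPU assumed)
        recv_buf = cp.zeros(MPI.COMM_WORLD.Get_size() * send_buf.size,
                            dtype=send_buf.dtype)
    nccl_comm.allGather(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf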
