
Commit ca558fd

feat: WIP DistributedMix
A new DistributedMixIn class is created with the aim of simplifying and unifying all communication calls in both DistributedArray and the operators (further hiding away all implementation details).
1 parent 31068f9 commit ca558fd

File tree

4 files changed: +68, -63 lines changed


pylops_mpi/Distributed.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+from typing import Any, NewType
+
+from mpi4py import MPI
+from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
+from pylops_mpi.utils._mpi import mpi_allreduce
+from pylops_mpi.utils import deps
+
+cupy_message = pylops_deps.cupy_import("the DistributedArray module")
+nccl_message = deps.nccl_import("the DistributedArray module")
+
+if nccl_message is None and cupy_message is None:
+    from pylops_mpi.utils._nccl import (
+        nccl_allgather, nccl_allreduce,
+        nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv,
+        _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
+    )
+
+
+class DistributedMixIn:
+    r"""Distributed Mixin class
+
+    This class implements all methods associated with communication primitives
+    from MPI and NCCL. It is mostly in charge of identifying which communicator
+    to use and whether the buffered or object MPI primitives should be used
+    (the former in the case of NumPy arrays, or CuPy arrays when a CUDA-aware
+    MPI installation is available; the latter with CuPy arrays when a CUDA-aware
+    MPI installation is not available).
+    """
+    def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+        """Allreduce operation
+        """
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
+        else:
+            return mpi_allreduce(self.base_comm, send_buf,
+                                 recv_buf, self.engine, op)
+
+    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
+        """Allreduce operation with subcommunicator
+        """
+        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
+            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
+        else:
+            return mpi_allreduce(self.sub_comm, send_buf,
+                                 recv_buf, self.engine, op)

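As a reference for how the new mixin is meant to be consumed, here is a minimal sketch (not part of this commit): a host class exposes the attributes the mixin reads (base_comm, base_comm_nccl, sub_comm, engine) and then simply calls the inherited _allreduce. The LocalSum class name is purely illustrative.

# Minimal usage sketch (not from the commit): a host class provides the
# attributes DistributedMixIn expects and delegates the reduction to it.
import numpy as np
from mpi4py import MPI

from pylops_mpi.Distributed import DistributedMixIn


class LocalSum(DistributedMixIn):
    """Toy container that sums a local NumPy array across all ranks."""

    def __init__(self, local_array, base_comm=MPI.COMM_WORLD):
        self.local_array = np.asarray(local_array)
        self.base_comm = base_comm    # MPI communicator used by _allreduce
        self.base_comm_nccl = None    # no NCCL communicator in this sketch
        self.sub_comm = base_comm     # communicator used by _allreduce_subcomm
        self.engine = "numpy"         # selects the buffered MPI code path

    def global_sum(self):
        # _allreduce is inherited from DistributedMixIn and dispatches to
        # NCCL or MPI depending on base_comm_nccl
        return self._allreduce(self.local_array, op=MPI.SUM)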
pylops_mpi/DistributedArray.py

Lines changed: 9 additions & 57 deletions
@@ -3,12 +3,13 @@
 from typing import Any, List, Optional, Tuple, Union, NewType

 import numpy as np
-import os
 from mpi4py import MPI
+from pylops_mpi.Distributed import DistributedMixIn
 from pylops.utils import DTypeLike, NDArray
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops.utils._internal import _value_or_sized_to_tuple
 from pylops.utils.backend import get_array_module, get_module, get_module_name
+from pylops_mpi.utils._mpi import mpi_allreduce, mpi_send
 from pylops_mpi.utils import deps

 cupy_message = pylops_deps.cupy_import("the DistributedArray module")
@@ -22,10 +23,6 @@

 NcclCommunicatorType = NewType("NcclCommunicator", NcclCommunicator)

-if int(os.environ.get("PYLOPS_MPI_CUDA_AWARE", 0)):
-    is_cuda_aware_mpi = True
-else:
-    is_cuda_aware_mpi = False

 class Partition(Enum):
     r"""Enum class
@@ -104,7 +101,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] = MPI.COMM_WORLD):
     return sub_comm


-class DistributedArray:
+class DistributedArray(DistributedMixIn):
     r"""Distributed Numpy Arrays

     Multidimensional NumPy-like distributed arrays.
@@ -477,44 +474,6 @@ def _check_mask(self, dist_array):
         if not np.array_equal(self.mask, dist_array.mask):
             raise ValueError("Mask of both the arrays must be same")

-    def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
-        """Allreduce operation
-        """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
-        else:
-            if is_cuda_aware_mpi or self.engine == "numpy":
-                ncp = get_module(self.engine)
-                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
-                self.base_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-            else:
-                # CuPy with non-CUDA-aware MPI
-                if recv_buf is None:
-                    return self.base_comm.allreduce(send_buf, op)
-                # For MIN and MAX which require recv_buf
-                self.base_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-
-    def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
-        """Allreduce operation with subcommunicator
-        """
-        if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
-            return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
-        else:
-            if is_cuda_aware_mpi or self.engine == "numpy":
-                ncp = get_module(self.engine)
-                recv_buf = ncp.zeros(send_buf.size, dtype=send_buf.dtype)
-                self.sub_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-            else:
-                # CuPy with non-CUDA-aware MPI
-                if recv_buf is None:
-                    return self.sub_comm.allreduce(send_buf, op)
-                # For MIN and MAX which require recv_buf
-                self.sub_comm.Allreduce(send_buf, recv_buf, op)
-                return recv_buf
-
     def _allgather(self, send_buf, recv_buf=None):
         """Allgather operation
         """
@@ -556,16 +515,9 @@ def _send(self, send_buf, dest, count=None, tag=0):
                 count = send_buf.size
             nccl_send(self.base_comm_nccl, send_buf, dest, count)
         else:
-            if is_cuda_aware_mpi or self.engine == "numpy":
-                # Determine MPI type based on array dtype
-                mpi_type = MPI._typedict[send_buf.dtype.char]
-                if count is None:
-                    count = send_buf.size
-                self.base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
-            else:
-                # Uses CuPy without CUDA-aware MPI
-                self.base_comm.send(send_buf, dest, tag)
-
+            mpi_send(self.base_comm,
+                     send_buf, dest, count, tag=tag,
+                     engine=self.engine)

     def _recv(self, recv_buf=None, source=0, count=None, tag=0):
         """Receive operation
@@ -579,7 +531,7 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0):
             return recv_buf
         else:
             # NumPy + MPI will benefit from buffered communication regardless of MPI installation
-            if is_cuda_aware_mpi or self.engine == "numpy":
+            if deps.cuda_aware_mpi_enabled or self.engine == "numpy":
                 ncp = get_module(self.engine)
                 if recv_buf is None:
                     if count is None:
@@ -734,7 +686,7 @@ def _compute_vector_norm(self, local_array: NDArray,
             # CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
             # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
             send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
                 # CuPy + non-CUDA-aware MPI: This will call non-buffered communication
                 # which return a list of object - must be copied back to a GPU memory.
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
@@ -750,7 +702,7 @@ def _compute_vector_norm(self, local_array: NDArray,
             # Calculate min followed by min reduction
             # See the comment above in +infinity norm
             send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
-            if self.engine == "cupy" and self.base_comm_nccl is None and not is_cuda_aware_mpi:
+            if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
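The buffered-versus-object branching removed from _allreduce, _allreduce_subcomm and _send now lives in the pylops_mpi.utils._mpi helpers (mpi_allreduce, mpi_send), which are imported above but whose implementation is not part of this diff. Based on the removed code and on the new call site, an mpi_send helper could look roughly like the sketch below; treat the body as an assumption, not the actual pylops_mpi.utils._mpi code.

# Hedged sketch of an mpi_send helper, reconstructed from the logic removed
# from DistributedArray._send; names and signature are assumptions.
from mpi4py import MPI

from pylops_mpi.utils import deps


def mpi_send(base_comm, send_buf, dest, count=None, tag=0, engine="numpy"):
    """Send send_buf to rank dest, using buffered or object MPI primitives."""
    if deps.cuda_aware_mpi_enabled or engine == "numpy":
        # NumPy arrays (or CuPy arrays with CUDA-aware MPI): buffered send
        mpi_type = MPI._typedict[send_buf.dtype.char]
        if count is None:
            count = send_buf.size
        base_comm.Send([send_buf, count, mpi_type], dest=dest, tag=tag)
    else:
        # CuPy arrays without CUDA-aware MPI: object (pickle-based) send
        base_comm.send(send_buf, dest, tag)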

pylops_mpi/basicoperators/VStack.py

Lines changed: 10 additions & 6 deletions
@@ -15,6 +15,7 @@
     Partition,
     StackedDistributedArray
 )
+from pylops_mpi.Distributed import DistributedMixIn
 from pylops_mpi.utils.decorators import reshaped
 from pylops_mpi.utils import deps

@@ -25,7 +26,7 @@
     from pylops_mpi.utils._nccl import nccl_allreduce


-class MPIVStack(MPILinearOperator):
+class MPIVStack(DistributedMixIn, MPILinearOperator):
     r"""MPI VStack Operator

     Create a vertical stack of a set of linear operators using MPI. Each rank must
@@ -141,16 +142,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
     @reshaped(forward=False, stacking=True)
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
-        y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
+        # TODO: consider adding base_comm, base_comm_nccl, engine to the
+        # input parameters of _allreduce instead of relying on self
+        self.base_comm, self.base_comm_nccl, self.engine = \
+            x.base_comm, x.base_comm_nccl, x.engine
+        y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm,
+                             base_comm_nccl=x.base_comm_nccl,
+                             partition=Partition.BROADCAST,
                              engine=x.engine, dtype=self.dtype)
         y1 = []
         for iop, oper in enumerate(self.ops):
             y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
         y1 = ncp.sum(ncp.vstack(y1), axis=0)
-        if deps.nccl_enabled and x.base_comm_nccl:
-            y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM)
-        else:
-            y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
+        y[:] = self._allreduce(y1, op=MPI.SUM)
         return y

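The TODO left in _rmatvec suggests passing the communicators and engine to _allreduce explicitly instead of copying them onto the operator. A hypothetical variant of the mixin method along those lines (not part of this commit) could be:

# Hypothetical reworking suggested by the TODO above: _allreduce accepts the
# communicators and engine as arguments, falling back to the attributes set by
# the host class when they are not provided. A sketch, not committed code.
from mpi4py import MPI

from pylops_mpi.utils._mpi import mpi_allreduce
from pylops_mpi.utils import deps


def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM,
               base_comm=None, base_comm_nccl=None, engine=None):
    """Allreduce operation (intended as a method of DistributedMixIn)."""
    base_comm = self.base_comm if base_comm is None else base_comm
    base_comm_nccl = self.base_comm_nccl if base_comm_nccl is None else base_comm_nccl
    engine = self.engine if engine is None else engine
    if deps.nccl_enabled and base_comm_nccl is not None:
        # guarded import, mirroring the conditional import in Distributed.py
        from pylops_mpi.utils._nccl import nccl_allreduce
        return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
    return mpi_allreduce(base_comm, send_buf, recv_buf, engine, op)


# MPIVStack._rmatvec would then call
#     y[:] = self._allreduce(y1, op=MPI.SUM, base_comm=x.base_comm,
#                            base_comm_nccl=x.base_comm_nccl, engine=x.engine)
# and the self.base_comm / self.base_comm_nccl / self.engine assignments
# would no longer be needed.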
pylops_mpi/utils/deps.py

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ def nccl_import(message: Optional[str] = None) -> str:
     return nccl_message


+cuda_aware_mpi_enabled: bool = (
+    True if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 1)) == 1 else False
+)
+
 nccl_enabled: bool = (
     True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False
 )