
Commit a08924b

fix flake8
1 parent dbe1f30 commit a08924b

4 files changed, +30 -34 lines changed


pylops_mpi/Distributed.py

Lines changed: 6 additions & 10 deletions
@@ -1,5 +1,3 @@
-from typing import Any, NewType, Tuple
-
 from mpi4py import MPI
 from pylops.utils import deps as pylops_deps  # avoid namespace crashes with pylops_mpi.utils
 from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
@@ -10,10 +8,10 @@
 
 if nccl_message is None and cupy_message is None:
     from pylops_mpi.utils._nccl import (
-        nccl_allgather, nccl_allreduce,
-        nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv
+        nccl_allgather, nccl_allreduce, nccl_send, nccl_recv
     )
 
+
 class DistributedMixIn:
     r"""Distributed Mixin class
 
@@ -30,7 +28,7 @@ def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
         else:
-            return mpi_allreduce(self.base_comm, send_buf,
+            return mpi_allreduce(self.base_comm, send_buf,
                                  recv_buf, self.engine, op)
 
     def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
@@ -39,7 +37,7 @@ def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
         if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
             return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
         else:
-            return mpi_allreduce(self.sub_comm, send_buf,
+            return mpi_allreduce(self.sub_comm, send_buf,
                                  recv_buf, self.engine, op)
 
     def _allgather(self, send_buf, recv_buf=None):
@@ -96,7 +94,5 @@ def _recv(self, recv_buf=None, source=0, count=None, tag=0):
             return recv_buf
         else:
            return mpi_recv(self.base_comm,
-                           recv_buf, source, count, tag=tag,
-                           engine=self.engine)
-
-
+                           recv_buf, source, count, tag=tag,
+                           engine=self.engine)
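For context, the `_allreduce`, `_allreduce_subcomm` and `_recv` methods touched above all follow the same pattern: use the NCCL communicator when one is attached to the object, otherwise fall back to the MPI helpers in `pylops_mpi.utils._mpi`. A minimal standalone sketch of that branching (hypothetical helper name, simplified guard, no sub-communicator handling), not the library's actual implementation:

    # Sketch only: mirrors the NCCL-vs-MPI branching of DistributedMixIn._allreduce
    from mpi4py import MPI


    def allreduce_dispatch(base_comm, base_comm_nccl, send_buf,
                           recv_buf=None, engine="numpy", op=MPI.SUM):
        if base_comm_nccl is not None:
            # GPU-to-GPU collective; import deferred so the MPI-only path
            # works without NCCL/CuPy installed
            from pylops_mpi.utils._nccl import nccl_allreduce
            return nccl_allreduce(base_comm_nccl, send_buf, recv_buf, op)
        # CPU (or CUDA-aware MPI) path, same call as in the diff above
        from pylops_mpi.utils._mpi import mpi_allreduce
        return mpi_allreduce(base_comm, send_buf, recv_buf, engine, op)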

pylops_mpi/DistributedArray.py

Lines changed: 3 additions & 3 deletions
@@ -15,7 +15,7 @@
 nccl_message = deps.nccl_import("the DistributedArray module")
 
 if nccl_message is None and cupy_message is None:
-    from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split
+    from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split
     from cupy.cuda.nccl import NcclCommunicator
 else:
     NcclCommunicator = Any
@@ -613,14 +613,14 @@ def _compute_vector_norm(self, local_array: NDArray,
             # with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
             send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
             if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
-                # CuPy + non-CUDA-aware MPI: This will call non-buffered communication
+                # CuPy + non-CUDA-aware MPI: This will call non-buffered communication
                 # which return a list of object - must be copied back to a GPU memory.
                 recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
                 recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
             else:
                 recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
                 # TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL
-                # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it.
+                # the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it.
                 # There may be a way to unify it - may be something to do with how we allocate the recv_buf.
             if self.base_comm_nccl:
                 recv_buf = ncp.squeeze(recv_buf, axis=axis)
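The second hunk sits in the infinity-norm branch of `_compute_vector_norm`: each rank reduces its local array with max(|x|) and the partial results are then combined with an `MPI.MAX` allreduce (copying to host first when CuPy is used without CUDA-aware MPI). A NumPy-only sketch of that reduction pattern, assuming `mpi4py` and one local 1-D array per rank (the function name is illustrative, not part of the library):

    # Sketch of the inf-norm reduction pattern (NumPy engine only)
    import numpy as np
    from mpi4py import MPI


    def distributed_inf_norm(local_array, comm=MPI.COMM_WORLD):
        # local partial result: max of |x| on this rank
        send_buf = np.max(np.abs(local_array))
        # global inf-norm: MAX-allreduce over all ranks
        return comm.allreduce(send_buf, op=MPI.MAX)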

pylops_mpi/utils/_mpi.py

Lines changed: 20 additions & 21 deletions
@@ -9,13 +9,14 @@
     "_unroll_allgather_recv"
 ]
 
-from typing import Optional, Tuple
+from typing import Optional
 
 import numpy as np
 from mpi4py import MPI
 from pylops.utils.backend import get_module
 from pylops_mpi.utils import deps
 
+
 # TODO: return type annotation for both cupy and numpy
 def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
     r""" Prepare send_buf and recv_buf for NCCL allgather (nccl_allgather)
@@ -33,7 +34,7 @@ def _prepare_allgather_inputs(send_buf, send_buf_shapes, engine):
         The data buffer from the local GPU to be sent for allgather.
     send_buf_shapes: :obj:`list`
         A list of shapes for each GPU send_buf (used to calculate padding size)
-    engine : :obj:`str`
+    engine : :obj:`str`
        Engine used to store array (``numpy`` or ``cupy``)
 
     Returns
@@ -96,20 +97,21 @@ def _unroll_allgather_recv(recv_buf, padded_send_buf_shape, send_buf_shapes) ->
 
     return chunks
 
+
 def mpi_allreduce(base_comm: MPI.Comm,
-                  send_buf, recv_buf=None,
+                  send_buf, recv_buf=None,
                   engine: Optional[str] = "numpy",
                   op: MPI.Op = MPI.SUM) -> np.ndarray:
-    """MPI_Allreduce/allreduce
-
-    Dispatch allreduce routine based on type of input and availability of
+    """MPI_Allreduce/allreduce
+
+    Dispatch allreduce routine based on type of input and availability of
     CUDA-Aware MPI
 
     Parameters
     ----------
     base_comm : :obj:`MPI.Comm`
         Base MPI Communicator.
-    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
         The data buffer from the local GPU to be reduced.
     recv_buf : :obj:`cupy.ndarray`, optional
         The buffer to store the result of the reduction. If None,
@@ -121,10 +123,10 @@ def mpi_allreduce(base_comm: MPI.Comm,
 
     Returns
     -------
-    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+    recv_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
         A buffer containing the result of the reduction, broadcasted
         to all GPUs.
-
+
     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         ncp = get_module(engine)
@@ -141,9 +143,8 @@ def mpi_allreduce(base_comm: MPI.Comm,
 
 
 def mpi_allgather(base_comm: MPI.Comm,
-                  send_buf, recv_buf=None,
-                  engine: Optional[str] = "numpy",
-                  ) -> np.ndarray:
+                  send_buf, recv_buf=None,
+                  engine: Optional[str] = "numpy") -> np.ndarray:
 
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         send_shapes = base_comm.allgather(send_buf.shape)
@@ -165,15 +166,15 @@ def mpi_send(base_comm: MPI.Comm,
              engine: Optional[str] = "numpy",
              ) -> None:
     """MPI_Send/send
-
-    Dispatch send routine based on type of input and availability of
+
+    Dispatch send routine based on type of input and availability of
     CUDA-Aware MPI
 
     Parameters
     ----------
     base_comm : :obj:`MPI.Comm`
         Base MPI Communicator.
-    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
+    send_buf : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
         The array containing data to send.
     dest: :obj:`int`
         The rank of the destination CPU/GPU device.
@@ -183,7 +184,6 @@ def mpi_send(base_comm: MPI.Comm,
         Tag of the message to be sent.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
-
     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         # Determine MPI type based on array dtype
@@ -195,11 +195,12 @@ def mpi_send(base_comm: MPI.Comm,
         # Uses CuPy without CUDA-aware MPI
         base_comm.send(send_buf, dest, tag)
 
+
 def mpi_recv(base_comm: MPI.Comm,
-             recv_buf=None, source=0, count=None, tag=0,
-             engine: Optional[str] = "numpy") -> np.ndarray:
+             recv_buf=None, source=0, count=None, tag=0,
+             engine: Optional[str] = "numpy") -> np.ndarray:
     """ MPI_Recv/recv
-    Dispatch receive routine based on type of input and availability of
+    Dispatch receive routine based on type of input and availability of
     CUDA-Aware MPI
 
     Parameters
@@ -216,7 +217,6 @@ def mpi_recv(base_comm: MPI.Comm,
         Tag of the message to be sent.
     engine : :obj:`str`, optional
         Engine used to store array (``numpy`` or ``cupy``)
-
     """
     if deps.cuda_aware_mpi_enabled or engine == "numpy":
         ncp = get_module(engine)
@@ -233,4 +233,3 @@ def mpi_recv(base_comm: MPI.Comm,
         # Uses CuPy without CUDA-aware MPI
         recv_buf = base_comm.recv(source=source, tag=tag)
     return recv_buf
-
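The docstrings re-wrapped above describe the dispatch rule used throughout `_mpi.py`: take the buffered (uppercase) mpi4py route when CUDA-aware MPI is available or the arrays are NumPy, and fall back to the pickle-based (lowercase) calls for CuPy without CUDA-aware MPI. A condensed sketch of that rule for allreduce, with `cuda_aware_mpi_enabled` standing in for `pylops_mpi.utils.deps.cuda_aware_mpi_enabled` and the function name purely illustrative:

    import numpy as np
    from mpi4py import MPI

    cuda_aware_mpi_enabled = False  # assumption for this sketch


    def allreduce_sketch(comm, send_buf, recv_buf=None, engine="numpy", op=MPI.SUM):
        if cuda_aware_mpi_enabled or engine == "numpy":
            # buffered path: operates directly on array memory
            if recv_buf is None:
                recv_buf = np.empty_like(send_buf)
            comm.Allreduce(send_buf, recv_buf, op)
            return recv_buf
        # CuPy without CUDA-aware MPI: object-based (pickled) communication
        return comm.allreduce(send_buf, op)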

pylops_mpi/utils/_nccl.py

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ def _nccl_sync():
         return
     cp.cuda.runtime.deviceSynchronize()
 
+
 def mpi_op_to_nccl(mpi_op) -> NcclOp:
     """ Map MPI reduction operation to NCCL equivalent
 