Skip to content

Commit 733e371

Browse files
committed
reformat
1 parent 19d61a4 commit 733e371

File tree

5 files changed

+127
-134
lines changed

5 files changed

+127
-134
lines changed

lightllm/distributed/device_communicators/pynccl.py

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Adapted from
1+
# Adapted from
22
# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/distributed/device_communicators/pynccl.py
33
# of the vllm-project/vllm GitHub repository.
44
#
@@ -26,15 +26,20 @@
2626
from torch.distributed import ProcessGroup, ReduceOp
2727

2828
from lightllm.distributed.device_communicators.pynccl_wrapper import (
29-
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
30-
ncclRedOpTypeEnum, ncclUniqueId)
29+
NCCLLibrary,
30+
buffer_type,
31+
cudaStream_t,
32+
ncclComm_t,
33+
ncclDataTypeEnum,
34+
ncclRedOpTypeEnum,
35+
ncclUniqueId,
36+
)
3137
from lightllm.utils.log_utils import init_logger
3238

3339
logger = init_logger(__name__)
3440

3541

3642
class PyNcclCommunicator:
37-
3843
def __init__(
3944
self,
4045
group: ProcessGroup,
@@ -53,8 +58,9 @@ def __init__(
5358
is bound to a unique device.
5459
"""
5560
assert dist.is_initialized()
56-
assert dist.get_backend(group) != dist.Backend.NCCL, (
57-
"PyNcclCommunicator should be attached to a non-NCCL group.")
61+
assert (
62+
dist.get_backend(group) != dist.Backend.NCCL
63+
), "PyNcclCommunicator should be attached to a non-NCCL group."
5864
self.group = group
5965
# note: this rank is the rank in the group
6066
self.rank = dist.get_rank(group)
@@ -105,8 +111,7 @@ def __init__(
105111
# `torch.cuda.device` is a context manager that changes the
106112
# current cuda device to the specified one
107113
with torch.cuda.device(device):
108-
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
109-
self.world_size, self.unique_id, self.rank)
114+
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(self.world_size, self.unique_id, self.rank)
110115
self.stream = torch.cuda.Stream()
111116

112117
# A small all_reduce for warmup.
@@ -120,54 +125,66 @@ def __init__(
120125
# when we are using CUDA graph.
121126
self.disabled = True
122127

123-
def all_reduce(self,
124-
tensor: torch.Tensor,
125-
op: ReduceOp = ReduceOp.SUM,
126-
stream=None):
128+
def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None):
127129
if self.disabled:
128130
return
129131
# nccl communicator created on a specific device
130132
# will only work on tensors on the same device
131133
# otherwise it will cause "illegal memory access"
132134
assert tensor.device == self.device, (
133135
f"this nccl communicator is created to work on {self.device}, "
134-
f"but the input tensor is on {tensor.device}")
136+
f"but the input tensor is on {tensor.device}"
137+
)
135138
if stream is None:
136139
stream = self.stream
137-
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
138-
buffer_type(tensor.data_ptr()), tensor.numel(),
139-
ncclDataTypeEnum.from_torch(tensor.dtype),
140-
ncclRedOpTypeEnum.from_torch(op), self.comm,
141-
cudaStream_t(stream.cuda_stream))
140+
self.nccl.ncclAllReduce(
141+
buffer_type(tensor.data_ptr()),
142+
buffer_type(tensor.data_ptr()),
143+
tensor.numel(),
144+
ncclDataTypeEnum.from_torch(tensor.dtype),
145+
ncclRedOpTypeEnum.from_torch(op),
146+
self.comm,
147+
cudaStream_t(stream.cuda_stream),
148+
)
142149

143150
def send(self, tensor: torch.Tensor, dst: int, stream=None):
144151
if self.disabled:
145152
return
146153
assert tensor.device == self.device, (
147154
f"this nccl communicator is created to work on {self.device}, "
148-
f"but the input tensor is on {tensor.device}")
155+
f"but the input tensor is on {tensor.device}"
156+
)
149157
if stream is None:
150158
stream = self.stream
151-
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
152-
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
153-
self.comm, cudaStream_t(stream.cuda_stream))
159+
self.nccl.ncclSend(
160+
buffer_type(tensor.data_ptr()),
161+
tensor.numel(),
162+
ncclDataTypeEnum.from_torch(tensor.dtype),
163+
dst,
164+
self.comm,
165+
cudaStream_t(stream.cuda_stream),
166+
)
154167

155168
def recv(self, tensor: torch.Tensor, src: int, stream=None):
156169
if self.disabled:
157170
return
158171
assert tensor.device == self.device, (
159172
f"this nccl communicator is created to work on {self.device}, "
160-
f"but the input tensor is on {tensor.device}")
173+
f"but the input tensor is on {tensor.device}"
174+
)
161175
if stream is None:
162176
stream = self.stream
163-
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
164-
ncclDataTypeEnum.from_torch(tensor.dtype), src,
165-
self.comm, cudaStream_t(stream.cuda_stream))
177+
self.nccl.ncclRecv(
178+
buffer_type(tensor.data_ptr()),
179+
tensor.numel(),
180+
ncclDataTypeEnum.from_torch(tensor.dtype),
181+
src,
182+
self.comm,
183+
cudaStream_t(stream.cuda_stream),
184+
)
166185

167186
@contextmanager
168-
def change_state(self,
169-
enable: Optional[bool] = None,
170-
stream: Optional[torch.cuda.Stream] = None):
187+
def change_state(self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None):
171188
"""
172189
A context manager to change the state of the communicator.
173190
"""

lightllm/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Adapted from
1+
# Adapted from
22
# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/distributed/device_communicators/pynccl_wrapper.py
33
# of the vllm-project/vllm GitHub repository.
44
#
@@ -146,46 +146,43 @@ class NCCLLibrary:
146146
# const char* ncclGetErrorString(ncclResult_t result)
147147
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
148148
# ncclResult_t ncclGetVersion(int *version);
149-
Function("ncclGetVersion", ncclResult_t,
150-
[ctypes.POINTER(ctypes.c_int)]),
149+
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
151150
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
152-
Function("ncclGetUniqueId", ncclResult_t,
153-
[ctypes.POINTER(ncclUniqueId)]),
151+
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
154152
# ncclResult_t ncclCommInitRank(
155153
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
156154
# note that ncclComm_t is a pointer type, so the first argument
157155
# is a pointer to a pointer
158-
Function("ncclCommInitRank", ncclResult_t, [
159-
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
160-
ctypes.c_int
161-
]),
156+
Function(
157+
"ncclCommInitRank", ncclResult_t, [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int]
158+
),
162159
# ncclResult_t ncclAllReduce(
163160
# const void* sendbuff, void* recvbuff, size_t count,
164161
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
165162
# cudaStream_t stream);
166163
# note that cudaStream_t is a pointer type, so the last argument
167164
# is a pointer
168-
Function("ncclAllReduce", ncclResult_t, [
169-
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
170-
ncclRedOp_t, ncclComm_t, cudaStream_t
171-
]),
172-
165+
Function(
166+
"ncclAllReduce",
167+
ncclResult_t,
168+
[buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t],
169+
),
173170
# ncclResult_t ncclSend(
174171
# const void* sendbuff, size_t count, ncclDataType_t datatype,
175172
# int dest, ncclComm_t comm, cudaStream_t stream);
176-
Function("ncclSend", ncclResult_t, [
177-
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
178-
ncclComm_t, cudaStream_t
179-
]),
180-
173+
Function(
174+
"ncclSend",
175+
ncclResult_t,
176+
[buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t],
177+
),
181178
# ncclResult_t ncclRecv(
182179
# void* recvbuff, size_t count, ncclDataType_t datatype,
183180
# int src, ncclComm_t comm, cudaStream_t stream);
184-
Function("ncclRecv", ncclResult_t, [
185-
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
186-
ncclComm_t, cudaStream_t
187-
]),
188-
181+
Function(
182+
"ncclRecv",
183+
ncclResult_t,
184+
[buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, ncclComm_t, cudaStream_t],
185+
),
189186
# be cautious! this is a collective call, it will block until all
190187
# processes in the communicator have called this function.
191188
# because Python object destruction can happen in random order,
@@ -219,8 +216,10 @@ def __init__(self, so_file: Optional[str] = None):
219216
"or it does not support the current platform %s."
220217
"If you already have the library, please set the "
221218
"environment variable VLLM_NCCL_SO_PATH"
222-
" to point to the correct nccl library path.", so_file,
223-
platform.platform())
219+
" to point to the correct nccl library path.",
220+
so_file,
221+
platform.platform(),
222+
)
224223
raise e
225224

226225
if so_file not in NCCLLibrary.path_to_dict_mapping:
@@ -253,45 +252,51 @@ def ncclGetVersion(self) -> str:
253252

254253
def ncclGetUniqueId(self) -> ncclUniqueId:
255254
unique_id = ncclUniqueId()
256-
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
257-
ctypes.byref(unique_id)))
255+
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
258256
return unique_id
259257

260-
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
261-
rank: int) -> ncclComm_t:
258+
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, rank: int) -> ncclComm_t:
262259
comm = ncclComm_t()
263-
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
264-
world_size, unique_id,
265-
rank))
260+
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), world_size, unique_id, rank))
266261
return comm
267262

268-
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
269-
count: int, datatype: int, op: int, comm: ncclComm_t,
270-
stream: cudaStream_t) -> None:
263+
def ncclAllReduce(
264+
self,
265+
sendbuff: buffer_type,
266+
recvbuff: buffer_type,
267+
count: int,
268+
datatype: int,
269+
op: int,
270+
comm: ncclComm_t,
271+
stream: cudaStream_t,
272+
) -> None:
271273
# `datatype` actually should be `ncclDataType_t`
272274
# and `op` should be `ncclRedOp_t`
273275
# both are aliases of `ctypes.c_int`
274276
# when we pass int to a function, it will be converted to `ctypes.c_int`
275277
# by ctypes automatically
276-
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
277-
datatype, op, comm,
278-
stream))
278+
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, datatype, op, comm, stream))
279279

280-
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
281-
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
282-
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
283-
dest, comm, stream))
280+
def ncclSend(
281+
self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t
282+
) -> None:
283+
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream))
284284

285-
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
286-
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
287-
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
288-
comm, stream))
285+
def ncclRecv(
286+
self, recvbuff: buffer_type, count: int, datatype: int, src: int, comm: ncclComm_t, stream: cudaStream_t
287+
) -> None:
288+
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream))
289289

290290
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
291291
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
292292

293293

294294
__all__ = [
295-
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
296-
"ncclComm_t", "cudaStream_t", "buffer_type"
295+
"NCCLLibrary",
296+
"ncclDataTypeEnum",
297+
"ncclRedOpTypeEnum",
298+
"ncclUniqueId",
299+
"ncclComm_t",
300+
"cudaStream_t",
301+
"buffer_type",
297302
]

0 commit comments

Comments
 (0)