Commit d45a1b2

Author: Weichao Luo (committed)
Commit message: fix style.
1 parent 4c8fd1a commit d45a1b2

File tree: 9 files changed, +318 −226 lines

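The hunks shown below all apply the same treatment: wrapped signatures and call sites are re-broken to one argument per line with a trailing comma, and multi-line string arguments get a closing parenthesis on its own line. The pattern matches what an auto-formatter such as black produces at a wide line length, though the commit message only says "fix style.", so the exact tool is an inference. A minimal before/after sketch of the convention (simplified, untyped signature; the real methods appear in the diffs below):

# Before: arguments packed onto wrapped lines, no trailing comma.
def send_to_decode_node(
    self, move_tasks, mem_managers, dp_size_in_node,
    nccl_comm
):
    ...

# After: one argument per line plus a trailing comma, so a later change to
# any single argument shows up as a one-line diff.
def send_to_decode_node(
    self,
    move_tasks,
    mem_managers,
    dp_size_in_node,
    nccl_comm,
):
    ...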

lightllm/common/deepseek2_mem_manager.py

Lines changed: 20 additions & 8 deletions
@@ -42,8 +42,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return
 
     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int,
-        nccl_comm: PyNcclCommunicator
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["Deepseek2MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -69,8 +72,11 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer
 
     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
-        nccl_comm: PyNcclCommunicator
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -102,8 +108,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return
 
     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
-        nccl_comm: PyNcclCommunicator
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         """
         使用 p2p triton kernel 进行数据复制和传输的实现方式。
@@ -155,8 +164,11 @@ def _get_kv_move_data_p2p(
         return move_buffer
 
     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
-        nccl_comm: PyNcclCommunicator
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         if not hasattr(self, "mem_ptrs_dict"):
             self.mem_ptrs_dict = {}
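For context on how the re-wrapped signature reads at a call site, here is a hedged sketch of a prefill-side caller. Only the parameter list and the dp_size_in_node == 1 assertion come from the diff above; the function name and the way the arguments are obtained are assumptions for illustration.

# Hypothetical caller sketch; everything outside the method signature is assumed.
def push_kv_to_decode(mem_manager, move_tasks, mem_managers, nccl_comm):
    # The method body asserts dp_size_in_node == 1, so only the single-DP case is valid here.
    mem_manager.send_to_decode_node(
        move_tasks=move_tasks,        # List[KVMoveTask] scheduled for transfer
        mem_managers=mem_managers,    # List["Deepseek2MemoryManager"], one per rank (assumed)
        dp_size_in_node=1,
        nccl_comm=nccl_comm,          # PyNcclCommunicator carrying the transfer
    )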

lightllm/common/mem_manager.py

Lines changed: 18 additions & 6 deletions
@@ -87,8 +87,11 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return
 
     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
-        nccl_comm: PyNcclCommunicator
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
 
@@ -124,7 +127,10 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer
 
     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
@@ -158,8 +164,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         return
 
     def send_to_decode_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
-        nccl_comm: PyNcclCommunicator
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
+        nccl_comm: PyNcclCommunicator,
     ):
         """
         使用 p2p triton kernel 进行数据复制和传输的实现方式。
@@ -190,7 +199,10 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer
 
     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int,
+        self,
+        move_tasks: List[KVMoveTask],
+        mem_managers: List["MemoryManager"],
+        dp_size_in_node: int,
         nccl_comm: PyNcclCommunicator,
     ):
         assert dp_size_in_node == 1
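The base MemoryManager methods mirror the Deepseek2 variants above, so the decode side is symmetric. A hedged sketch of the receiving call, with the same caveats (only the signature and the assertion are taken from the diff; the surrounding names are assumed):

# Hypothetical receiver sketch; names outside the method signature are assumed.
def pull_kv_from_prefill(mem_manager, move_tasks, mem_managers, nccl_comm):
    mem_manager.receive_from_prefill_node(
        move_tasks=move_tasks,
        mem_managers=mem_managers,
        dp_size_in_node=1,    # the method asserts dp_size_in_node == 1
        nccl_comm=nccl_comm,
    )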

lightllm/distributed/pynccl.py

Lines changed: 81 additions & 57 deletions
@@ -33,19 +33,27 @@
 from torch.distributed import ProcessGroup, ReduceOp, TCPStore
 
 from lightllm.distributed.pynccl_wrapper import (
-    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
-    ncclRedOpTypeEnum, ncclUniqueId)
+    NCCLLibrary,
+    buffer_type,
+    cudaStream_t,
+    ncclComm_t,
+    ncclDataTypeEnum,
+    ncclRedOpTypeEnum,
+    ncclUniqueId,
+)
 
 logger = logging.getLogger(__name__)
 
 _current_stream = None
 
+
 def current_stream() -> torch.cuda.Stream:
     global _current_stream
     if _current_stream is None:
         _current_stream = torch.cuda.current_stream()
     return _current_stream
 
+
 @dataclasses.dataclass
 class StatelessP2PProcessGroup:
     """A dataclass to hold a metadata store, and the rank, world_size of the
@@ -94,18 +102,13 @@ def expire_data(self):
 
     def recv_obj(self) -> Any:
         """Receive an object from a source rank."""
-        obj = pickle.loads(
-            self.store.get(
-                f"send_to/{self.dest_id}/{self.recv_src_counter}"))
+        obj = pickle.loads(self.store.get(f"send_to/{self.dest_id}/{self.recv_src_counter}"))
         self.recv_src_counter += 1
         return obj
 
     @staticmethod
     def create(
-        src_id: int,
-        dest_id: int,
-        is_server: bool,
-        store: torch._C._distributed_c10d.Store
+        src_id: int, dest_id: int, is_server: bool, store: torch._C._distributed_c10d.Store
     ) -> "StatelessP2PProcessGroup":
         """A replacement for `torch.distributed.init_process_group` that does not
         pollute the global state.
@@ -121,12 +124,11 @@ def create(
         used for exchanging metadata. With this function, process A and process B
         can call `StatelessProcessGroup.create` to form a group, and then process A, B,
         C, and D can call `StatelessProcessGroup.create` to form another group.
-        """ # noqa
+        """  # noqa
         return StatelessP2PProcessGroup(src_id=src_id, dest_id=dest_id, is_server=is_server, store=store)
 
 
 class PyNcclCommunicator:
-
     def __init__(
         self,
         group: Union[ProcessGroup, StatelessP2PProcessGroup],
@@ -146,8 +148,9 @@ def __init__(
         """
         if not isinstance(group, StatelessP2PProcessGroup):
             assert dist.is_initialized()
-            assert dist.get_backend(group) != dist.Backend.NCCL, (
-                "PyNcclCommunicator should be attached to a non-NCCL group.")
+            assert (
+                dist.get_backend(group) != dist.Backend.NCCL
+            ), "PyNcclCommunicator should be attached to a non-NCCL group."
             # note: this rank is the rank in the group
             self.rank = dist.get_rank(group)
             self.world_size = dist.get_world_size(group)
@@ -207,8 +210,7 @@ def __init__(
         # `torch.cuda.device` is a context manager that changes the
         # current cuda device to the specified one
         with torch.cuda.device(device):
-            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
-                self.world_size, self.unique_id, self.rank)
+            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(self.world_size, self.unique_id, self.rank)
 
         stream = current_stream()
         # A small all_reduce for warmup.
@@ -220,103 +222,120 @@ def __init__(
     def destroy(self):
         self.nccl.ncclCommDestroy(self.comm)
 
-    def all_reduce(self,
-                   in_tensor: torch.Tensor,
-                   op: ReduceOp = ReduceOp.SUM,
-                   stream=None) -> torch.Tensor:
+    def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None) -> torch.Tensor:
         if self.disabled:
             return None
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
         assert in_tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {in_tensor.device}")
+            f"but the input tensor is on {in_tensor.device}"
+        )
 
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
             stream = current_stream()
-        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
-                                buffer_type(out_tensor.data_ptr()),
-                                in_tensor.numel(),
-                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
-                                ncclRedOpTypeEnum.from_torch(op), self.comm,
-                                cudaStream_t(stream.cuda_stream))
+        self.nccl.ncclAllReduce(
+            buffer_type(in_tensor.data_ptr()),
+            buffer_type(out_tensor.data_ptr()),
+            in_tensor.numel(),
+            ncclDataTypeEnum.from_torch(in_tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
         return out_tensor
 
-    def all_gather(self,
-                   output_tensor: torch.Tensor,
-                   input_tensor: torch.Tensor,
-                   stream=None):
+    def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None):
         if self.disabled:
             return
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
         assert input_tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {input_tensor.device}")
+            f"but the input tensor is on {input_tensor.device}"
+        )
         if stream is None:
             stream = current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
-            cudaStream_t(stream.cuda_stream))
-
-    def reduce_scatter(self,
-                       output_tensor: torch.Tensor,
-                       input_tensor: torch.Tensor,
-                       op: ReduceOp = ReduceOp.SUM,
-                       stream=None):
+            buffer_type(output_tensor.data_ptr()),
+            input_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def reduce_scatter(
+        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
+    ):
         if self.disabled:
             return
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
         assert input_tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {input_tensor.device}")
+            f"but the input tensor is on {input_tensor.device}"
+        )
         if stream is None:
             stream = current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
+            buffer_type(output_tensor.data_ptr()),
+            output_tensor.numel(),
             ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op), self.comm,
-            cudaStream_t(stream.cuda_stream))
+            ncclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"but the input tensor is on {tensor.device}"
+        )
         if stream is None:
             stream = current_stream()
-        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
-                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
-                           self.comm, cudaStream_t(stream.cuda_stream))
+        self.nccl.ncclSend(
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            dst,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
 
     def recv(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"but the input tensor is on {tensor.device}"
+        )
         if stream is None:
             stream = current_stream()
-        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
-                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
-                           self.comm, cudaStream_t(stream.cuda_stream))
+        self.nccl.ncclRecv(
+            buffer_type(tensor.data_ptr()),
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
 
     def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"but the input tensor is on {tensor.device}"
+        )
         if stream is None:
             stream = current_stream()
         if src == self.rank:
@@ -326,7 +345,12 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
         else:
             sendbuff = buffer_type()
         recvbuff = buffer_type(tensor.data_ptr())
-        self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
-                                ncclDataTypeEnum.from_torch(tensor.dtype), src,
-                                self.comm, cudaStream_t(stream.cuda_stream))
-
+        self.nccl.ncclBroadcast(
+            sendbuff,
+            recvbuff,
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
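The reformatted PyNcclCommunicator methods keep their original behaviour: all_reduce allocates and returns a fresh output tensor, send/recv are rank-addressed, and every call falls back to current_stream() when no stream is passed. A hedged usage sketch follows; how the communicator is constructed is not fully visible in these hunks, so comm is taken as given, and the driver function itself is an assumption.

import torch
from torch.distributed import ReduceOp

# `comm` is an already-constructed PyNcclCommunicator; only the method
# signatures used here are taken from the diff above.
def warmup_and_exchange(comm, rank: int, peer: int):
    x = torch.ones(16, device=comm.device)

    # all_reduce returns a new tensor (out_tensor = torch.empty_like(in_tensor) internally).
    y = comm.all_reduce(x, op=ReduceOp.SUM)

    # Point-to-point: the tensor must live on comm.device, or the device assert fires.
    if rank == 0:
        comm.send(y, dst=peer)
    else:
        comm.recv(y, src=peer)
    return y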
