
Commit 78c1ffd

Author: niushengxiao
Commit message: use list
1 parent: 8078fd0

File tree: 3 files changed (+27, -42 lines)


lightllm/common/basemodel/basemodel.py

Lines changed: 11 additions & 13 deletions
@@ -78,8 +78,11 @@ def __init__(self, kvargs):
         self._init_mem_manager()
         self._init_weights()
-        self.stream1 = torch.cuda.Stream()
-        self.stream2 = torch.cuda.Stream()
+        self.stream_num = 2
+        self.graph = [None] * self.stream_num
+        self.stream = [None] * self.stream_num
+        for i in range(self.stream_num):
+            self.stream[i] = torch.cuda.Stream()
         self._init_kv_move_buffer()
         self._check_mem_size()
         self._init_req_manager()
@@ -205,16 +208,11 @@ def _init_datatype(self):
             raise ValueError(f"Unsupport datatype {self.data_type}!")
 
     def _init_cudagraph(self):
-        self.graph = (
-            None if self.disable_cudagraph else CudaGraph(self.stream1, self.graph_max_batch_size, self.graph_max_len_in_batch)
-        )
-        self.graph2 = (
-            None if self.disable_cudagraph else CudaGraph(self.stream2, self.graph_max_batch_size, self.graph_max_len_in_batch)
-        )
-        if self.graph is not None:
-            self.graph.warmup(self, 0)
-        if self.graph2 is not None:
-            self.graph2.warmup(self, 1)
+        for i in range(self.stream_num):
+            self.graph[i] = (
+                None if self.disable_cudagraph else CudaGraph(self.stream[i], self.graph_max_batch_size, self.graph_max_len_in_batch)
+            )
+            self.graph[i].warmup(self, i)
 
     def _init_custom(self):
         pass
@@ -363,7 +361,7 @@ def _decode(
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
 
         infer_state.init_some_extra_state(self, input_ids)
-        graph = self.graph if all_reduce_id == 0 else self.graph2
+        graph = self.graph[all_reduce_id]
         if graph is not None and graph.can_run(batch_size, max_len_in_batch):
             if graph.need_capture(batch_size):
                 infer_state.is_cuda_graph = True
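
For orientation, here is a minimal, self-contained sketch of the list-based bookkeeping this file moves to: one CUDA stream per slot, one optional CUDA-graph wrapper per slot, and lookup by index at decode time. The StreamedModel class, graph_cls, and pick_graph below are illustrative stand-ins, not lightllm's exact API.

import torch

class StreamedModel:
    """Illustrative sketch: parallel lists of streams and graph wrappers."""

    def __init__(self, stream_num=2, disable_cudagraph=False):
        self.stream_num = stream_num
        self.disable_cudagraph = disable_cudagraph
        # One CUDA stream per decode slice.
        self.stream = [torch.cuda.Stream() for _ in range(self.stream_num)]
        # One graph slot per stream; entries stay None when graphs are disabled.
        self.graph = [None] * self.stream_num

    def init_cudagraph(self, graph_cls, max_batch_size, max_len_in_batch):
        if self.disable_cudagraph:
            return
        for i in range(self.stream_num):
            # graph_cls stands in for lightllm's CudaGraph wrapper.
            self.graph[i] = graph_cls(self.stream[i], max_batch_size, max_len_in_batch)
            self.graph[i].warmup(self, i)

    def pick_graph(self, all_reduce_id):
        # Decode selects the graph for its slice by index instead of
        # branching between two named attributes.
        return self.graph[all_reduce_id]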

lightllm/distributed/communication_op.py

Lines changed: 14 additions & 27 deletions
@@ -51,34 +51,22 @@
 
 class CustomCommunicationOp:
     def __init__(self):
-        self.vllm_reduce1 = None
-        self.vllm_reduce2 = None
-        self.custom_gather = None
-        self.custom_gather2 = None
+        self.reduce_num = 2
+        self.vllm_reduce = [None] * self.reduce_num
+        self.custom_gather = [None] * self.reduce_num
         self.device_group = None
 
     @contextmanager
     def lightllm_capture_graph(self, all_reduce_id):
-        if all_reduce_id == 0:
-            if self.vllm_reduce1 is not None:
-                with self.vllm_reduce1.capture():
-                    if self.custom_gather is not None:
-                        with self.custom_gather.capture():
-                            yield
-                    else:
+        if self.vllm_reduce[all_reduce_id] is not None:
+            with self.vllm_reduce[all_reduce_id].capture():
+                if self.custom_gather[all_reduce_id] is not None:
+                    with self.custom_gather[all_reduce_id].capture():
                         yield
-            else:
-                yield
+                else:
+                    yield
         else:
-            if self.vllm_reduce2 is not None:
-                with self.vllm_reduce2.capture():
-                    if self.custom_gather2 is not None:
-                        with self.custom_gather2.capture():
-                            yield
-                    else:
-                        yield
-            else:
-                yield
+            yield
 
     def set_custom_reduce(self):
         ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "True").upper() in ["ON", "TRUE", "1"]
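
The refactored lightllm_capture_graph still needs three yield sites to cover the cases where a reducer or gather object is missing. An equivalent, purely illustrative way to express the same selection-by-index logic is contextlib.ExitStack, which enters whichever capture contexts exist for the given id and yields exactly once; as in the code above, this assumes capture() returns a context manager.

from contextlib import ExitStack, contextmanager

@contextmanager
def lightllm_capture_graph(self, all_reduce_id):
    # Sketch only: enter the vllm_reduce capture for this id if it exists,
    # then the custom_gather capture nested inside it, then yield once.
    with ExitStack() as stack:
        if self.vllm_reduce[all_reduce_id] is not None:
            stack.enter_context(self.vllm_reduce[all_reduce_id].capture())
            if self.custom_gather[all_reduce_id] is not None:
                stack.enter_context(self.custom_gather[all_reduce_id].capture())
        yield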
@@ -97,17 +85,16 @@ def set_custom_reduce(self):
             self.device_group = dist.new_group(ranks, backend="nccl")
 
         if ENABLE_VLLM_REDUCE and HAS_VLLM:
-            cpu_group1 = dist.new_group(ranks, backend="gloo")
-            self.vllm_reduce1 = CustomAllreduce(cpu_group1, torch.cuda.current_device())
-            cpu_group2 = dist.new_group(ranks, backend="gloo")
-            self.vllm_reduce2 = CustomAllreduce(cpu_group2, torch.cuda.current_device())
+            cpu_group = [dist.new_group(ranks, backend="gloo")] * self.reduce_num
+            for i in range(self.reduce_num):
+                self.vllm_reduce[i] = CustomAllreduce(cpu_group[i], torch.cuda.current_device())
             logger.info("Enable VLLM ALLReduce.")
 
         def _all_reduce_closure(input_, op=ReduceOp.SUM, group=self.device_group, async_op=False, all_reduce_id=0):
             if op != ReduceOp.SUM or async_op:
                 original_all_reduce(input_, op, group, async_op)
             else:
-                vllm_reduce = self.vllm_reduce1 if all_reduce_id == 0 else self.vllm_reduce2
+                vllm_reduce = self.vllm_reduce[all_reduce_id]
                 if vllm_reduce is not None and vllm_reduce.should_custom_ar(input_):
                     input_.data = vllm_reduce.custom_all_reduce(input_)
                 else:
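
One Python detail worth keeping in mind when reading the new group setup: `[dist.new_group(ranks, backend="gloo")] * self.reduce_num` evaluates dist.new_group once and repeats the same group object, whereas the cpu_group1/cpu_group2 code it replaces created two distinct gloo groups. If each reducer slot is meant to own its own group, a comprehension expresses that; the helper below is a hedged sketch, with reducer_cls standing in for vLLM's CustomAllreduce.

import torch
import torch.distributed as dist

def build_reducers(ranks, reduce_num, reducer_cls):
    # A list comprehension calls dist.new_group once per slot, so every
    # reducer gets its own gloo group; list multiplication would alias a
    # single group object across all slots.
    cpu_groups = [dist.new_group(ranks, backend="gloo") for _ in range(reduce_num)]
    return [reducer_cls(cpu_groups[i], torch.cuda.current_device()) for i in range(reduce_num)]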

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 2 additions & 2 deletions
@@ -108,9 +108,9 @@ def decode(self):
         # logits = self.model.forward(**kwargs)
         if kwargs["batch_size"] > 1:
             kwargs1, kwargs2 = split_kwargs(**kwargs)
-            with torch.cuda.stream(self.model.stream1):
+            with torch.cuda.stream(self.model.stream[0]):
                 logits1 = self.model.forward(**kwargs1)
-            with torch.cuda.stream(self.model.stream2):
+            with torch.cuda.stream(self.model.stream[1]):
                 logits2 = self.model.forward(**kwargs2)
             torch.cuda.synchronize()
             logits = torch.cat([logits1, logits2], dim=0)
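
A compact sketch of the split-decode path for reference, assuming split_kwargs returns two self-contained argument dicts (and, by extension, the all_reduce_id each half should use). The explicit torch.cuda.synchronize() matters because logits1 and logits2 are produced on side streams, while the following torch.cat runs on the default stream.

import torch

def dual_stream_decode(model, kwargs1, kwargs2):
    # Launch each half of the batch on its own stream so their kernels
    # (and replayed CUDA graphs) can overlap.
    with torch.cuda.stream(model.stream[0]):
        logits1 = model.forward(**kwargs1)
    with torch.cuda.stream(model.stream[1]):
        logits2 = model.forward(**kwargs2)
    # Join both streams before using the results on the default stream.
    torch.cuda.synchronize()
    return torch.cat([logits1, logits2], dim=0)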
