Commit 5071e7e

Author: niushengxiao
Commit message: fix
1 parent cf9df54 commit 5071e7e

File tree

2 files changed: 24 additions, 15 deletions


lightllm/server/router/manager.py

Lines changed: 2 additions & 2 deletions
@@ -437,7 +437,7 @@ async def loop_for_netio_req(self):
             if isinstance(recv_req, GroupReqIndexes):
                 queue0_len = self.req_queue[0].get_wait_req_num()
                 queue1_len = self.req_queue[1].get_wait_req_num()
-                self.add_req(recv_req, 1)# if queue0_len <= queue1_len else 1)
+                self.add_req(recv_req, 0 if queue0_len <= queue1_len else 1)
             else:
                 assert False, f"Error Req Inf {recv_req}"

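The hunk above restores load-aware dispatch: instead of always sending a request group to queue 1 (with the selection expression left behind in a trailing comment), the router again picks whichever of the two wait queues is shorter. A minimal sketch of that selection rule, using stand-in queue objects that only expose the get_wait_req_num() method seen in the diff (the surrounding RouterManager state is not reproduced here):

class WaitQueueStub:
    # Stand-in for the router's request queues; only the method used by the
    # routing decision is modeled here.
    def __init__(self, waiting):
        self.waiting = list(waiting)

    def get_wait_req_num(self):
        return len(self.waiting)

def pick_queue(req_queue):
    queue0_len = req_queue[0].get_wait_req_num()
    queue1_len = req_queue[1].get_wait_req_num()
    # Ties go to queue 0, matching the restored expression in the diff.
    return 0 if queue0_len <= queue1_len else 1

# Example: queue 0 has 3 waiting requests, queue 1 has 1 -> route to queue 1.
assert pick_queue([WaitQueueStub(range(3)), WaitQueueStub(range(1))]) == 1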

@@ -473,7 +473,7 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
     pipe_writer.send("init ok")
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
-    # loop.create_task(router.loop_for_fwd(0))
+    loop.create_task(router.loop_for_fwd(0))
     loop.create_task(router.loop_for_fwd(1))
     loop.run_until_complete(router.loop_for_netio_req())
     return
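
The second hunk re-enables the forward loop for stream 0, so both scheduling loops and the network-IO loop run on the same event loop. A minimal asyncio sketch of that scheduling pattern, with short-lived stand-ins in place of the real RouterManager coroutines (the real loops run indefinitely):

import asyncio

async def loop_for_fwd(stream_id):
    # Stand-in for RouterManager.loop_for_fwd; the real loop runs forever,
    # here it just performs a few scheduling "steps".
    for _ in range(3):
        await asyncio.sleep(0.01)

async def loop_for_netio_req():
    # Stand-in for the network-IO loop that feeds the two request queues.
    await asyncio.sleep(0.1)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(loop_for_fwd(0))  # the task re-enabled by this commit
loop.create_task(loop_for_fwd(1))
loop.run_until_complete(loop_for_netio_req())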

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py

Lines changed: 22 additions & 13 deletions
@@ -89,9 +89,11 @@ def split_kwargs_n(
     b_ready_cache_len: torch.Tensor = None,
     multimodal_params=None,
     is_prefill=True,
-    split_n=2):
+    split_n=2,
+    run_reqs=None):

     kwargs = [None] * split_n
+    run_reqs_list = [None] * split_n

     # Compute the batch size of each split
     batch_per_split = [batch_size // split_n] * split_n
@@ -124,11 +126,13 @@ def split_kwargs_n(
             token_start = cumulative_tokens
             token_end = token_start + split_tokens
             split_input_ids = input_ids[token_start:token_end]
+            reqs = run_reqs[token_start:token_end]
             split_mem_indexes = mem_indexes[token_start:token_end]
         else:
             # In the decode phase, split by batch
             split_input_ids = input_ids[start_idx:end_idx]
             split_mem_indexes = mem_indexes[start_idx:end_idx]
+            reqs = run_reqs[start_idx:end_idx]

         # Compute the remaining parameters of this split
         split_max_len = split_b_seq_len.max().item() if len(split_b_seq_len) > 0 else 0
@@ -139,6 +143,7 @@ def split_kwargs_n(
         if b_ready_cache_len is not None:
             split_b_ready_cache_len = b_ready_cache_len[start_idx:end_idx]

+        run_reqs_list[i] = reqs
         # Build the kwargs dict
         kwargs[i] = {
             "batch_size": len(split_b_req_idx),
@@ -161,7 +166,7 @@ def split_kwargs_n(
         # Update the cumulative token count
         cumulative_tokens += split_tokens

-    return kwargs
+    return kwargs, run_reqs_list

 class ContinuesBatchBackend(ModeBackend):
     def __init__(self) -> None:
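
Taken together, the split_kwargs_n hunks thread run_reqs through the splitting logic: each request slice is cut with the same index ranges as its tensors, and the function now returns the per-split request lists alongside the per-split kwargs. A simplified sketch of that aligned slicing for the decode case (one token per request), with stand-in tensors and request strings rather than the real InferReq objects:

import torch

def split_aligned(input_ids, run_reqs, split_n=2):
    # Simplified stand-in for split_kwargs_n (decode case): the token tensor
    # and the request list are cut with the same boundaries, so kwargs[i] and
    # run_reqs_list[i] always describe the same requests.
    batch_size = len(run_reqs)
    batch_per_split = [batch_size // split_n] * split_n
    for i in range(batch_size % split_n):
        batch_per_split[i] += 1  # assumed remainder handling, not shown in the diff

    kwargs = [None] * split_n
    run_reqs_list = [None] * split_n
    start = 0
    for i, n in enumerate(batch_per_split):
        end = start + n
        kwargs[i] = {"batch_size": n, "input_ids": input_ids[start:end]}
        run_reqs_list[i] = run_reqs[start:end]
        start = end
    return kwargs, run_reqs_list

# Example: 5 decode requests split across 2 streams -> batches of 3 and 2.
kwargs, run_reqs_list = split_aligned(torch.arange(5), [f"req{i}" for i in range(5)])
assert [k["batch_size"] for k in kwargs] == [3, 2]
assert run_reqs_list == [["req0", "req1", "req2"], ["req3", "req4"]]
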
@@ -175,7 +180,10 @@ def prefill(self, reqs: List[Tuple], stream_id):
         with torch.cuda.stream(self.model.stream[stream_id]):
             logits = self.model.forward(**kwargs)
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            torch.cuda.current_stream().synchronize()
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+            self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
+            # torch.cuda.current_stream().synchronize()

         # logits = self.model.forward(**kwargs)

@@ -193,10 +201,10 @@ def prefill(self, reqs: List[Tuple], stream_id):
         # logits = self.model.forward(**kwargs)

         # next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        # next_token_ids = next_token_ids.detach().cpu().numpy()
+        # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

-        return self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
+        # self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
         # return

     def decode(self, stream_id):
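
The live change here in prefill (and mirrored below in decode) drops the explicit torch.cuda.current_stream().synchronize() and instead finishes sampling, the GPU-to-host conversion, and post_handel inside the per-stream context. A minimal sketch of that pattern, assuming a CUDA device is available and substituting argmax sampling for the real sample() helper:

import torch

def finish_step_on_stream(stream, logits):
    # Stand-in for the new prefill/decode tail: sample, copy to host, and
    # post-process, all issued on one CUDA stream.
    with torch.cuda.stream(stream):
        probs = torch.softmax(logits, dim=-1)
        next_token_ids = torch.argmax(probs, dim=-1)  # real code uses sample()
        next_token_probs = probs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)
        # The blocking .cpu() copies are enqueued on `stream`, so the host
        # only proceeds once this stream's forward/sampling work is done;
        # that is why the explicit synchronize() call could be dropped.
        ids = next_token_ids.detach().cpu().numpy()
        logprobs = torch.log(next_token_probs).detach().cpu().numpy()
    return ids, logprobs  # handed to post_handel in the real backend

if torch.cuda.is_available():
    s = torch.cuda.Stream()
    fake_logits = torch.randn(4, 32000, device="cuda")  # stand-in for model.forward output
    s.wait_stream(torch.cuda.current_stream())  # fake_logits was made on the default stream
    ids, logprobs = finish_step_on_stream(s, fake_logits)
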
@@ -207,13 +215,15 @@ def decode(self, stream_id):
         with torch.cuda.stream(self.model.stream[stream_id]):
             logits = self.model.forward(**kwargs)
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            torch.cuda.current_stream().synchronize()
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+            self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)

         # logits = self.model.forward(**kwargs)

         # split_n = self.model.stream_num
         # if kwargs["batch_size"] > split_n - 1:
-        #     kwargs_list = split_kwargs_n(**kwargs, split_n=split_n)
+        #     kwargs_list, run_reqs_list = split_kwargs_n(**kwargs, split_n=split_n, run_reqs=run_reqs)
         #     logits = [None] * split_n
         #     for i in range(split_n):
         #         with torch.cuda.stream(self.model.stream[i]):
@@ -225,11 +235,10 @@ def decode(self, stream_id):
         # logits = self.model.forward(**kwargs)

         # next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-        self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
-        return stream_id
+        # next_token_ids = next_token_ids.detach().cpu().numpy()
+        # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        # self.post_handel(run_reqs, next_token_ids, next_token_logprobs, stream_id)
+        return

     def post_handel(self, run_reqs: List[InferReq], next_token_ids, next_token_logprobs, stream_id):
         finished_req_ids = []
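
The commented-out multi-stream split path in decode is also kept in sync with the new split_kwargs_n signature. If it were re-enabled, each split's kwargs would run on its own stream while staying paired with its request slice; a hedged sketch of that consumption pattern, using a hypothetical forward_fn and plain argmax in place of the backend's model and sampling:

import torch

def run_splits_on_streams(kwargs_list, run_reqs_list, forward_fn, streams):
    # Hypothetical consumer of split_kwargs_n's paired outputs: launch each
    # split's forward pass on its own stream, then collect results while
    # keeping them associated with the matching request slice.
    logits = [None] * len(kwargs_list)
    for i, kw in enumerate(kwargs_list):
        # Inputs were produced on the current (default) stream, so make the
        # side stream wait for them before launching work.
        streams[i].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(streams[i]):
            logits[i] = forward_fn(**kw)
    torch.cuda.synchronize()  # wait for every split before reading results
    return [(reqs, torch.argmax(lg, dim=-1).cpu().numpy())
            for reqs, lg in zip(run_reqs_list, logits)]

if torch.cuda.is_available():
    streams = [torch.cuda.Stream() for _ in range(2)]
    kwargs_list = [{"x": torch.randn(3, 8, device="cuda")} for _ in range(2)]
    run_reqs_list = [["req0", "req1", "req2"], ["req3", "req4", "req5"]]
    weight = torch.randn(8, 16, device="cuda")
    forward_fn = lambda x: x @ weight  # stand-in for self.model.forward
    paired = run_splits_on_streams(kwargs_list, run_reqs_list, forward_fn, streams)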
