Skip to content

Commit aa85395

Browse files
committed
chunked for diversemode
1 parent 1bc2342 commit aa85395

File tree

2 files changed

+61
-34
lines changed
  • lightllm/server/router

2 files changed

+61
-34
lines changed

lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,24 @@ def __init__(self) -> None:
2424
def init_custom(self):
2525
pass
2626

27-
def build_group(self, req_ids: List[int]):
28-
for r_id in req_ids:
29-
req: InferReq = g_infer_context.requests_mapping[r_id]
27+
def build_group(self, reqs: List[InferReq]):
28+
for req in reqs:
3029
group_req_id = req.shm_req.group_req_id
3130
if group_req_id not in g_infer_context.group_mapping:
3231
g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id)
33-
g_infer_context.group_mapping[group_req_id].add_req(r_id)
32+
g_infer_context.group_mapping[group_req_id].add_req(req.req_id)
3433

3534
def diverse_copy(self, groups: List[InferReqGroup]):
3635
batch_idx = []
3736
run_reqs = []
3837
for i in range(len(groups)):
3938
req_group = groups[i]
4039
best_of = req_group.best_of()
41-
if best_of > 1:
40+
_0_req_obj = req_group.get_req(0)
41+
if (
42+
best_of > 1 and
43+
_0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len()
44+
):
4245
req_group.diverse_copy(g_infer_context.req_manager, is_prefill=True)
4346
batch_idx.extend([i for _ in range(best_of)])
4447
else:
@@ -47,46 +50,69 @@ def diverse_copy(self, groups: List[InferReqGroup]):
4750
return batch_idx, run_reqs
4851

4952
def prefill(self, reqs: List[Tuple]):
50-
req_ids = self._init_reqs(reqs)
51-
self.build_group(req_ids)
52-
group_reqs = [
53-
g_infer_context.requests_mapping[req_id]
54-
for req_id in req_ids
55-
if convert_sub_id_to_group_id(req_id) == req_id
56-
]
57-
groups = [
58-
g_infer_context.group_mapping[req_id] for req_id in req_ids if convert_sub_id_to_group_id(req_id) == req_id
59-
]
60-
kwargs, group_run_reqs = prepare_prefill_inputs(
61-
group_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
62-
)
63-
logits = self.model.forward(**kwargs)
64-
batch_idx, run_reqs = self.diverse_copy(groups)
65-
logits = logits[batch_idx]
66-
next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
67-
next_token_ids = next_token_ids.detach().cpu().numpy()
68-
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
53+
req_ids = self._init_reqs(reqs, init_req_obj=False)
54+
# group_reqs = [
55+
# g_infer_context.requests_mapping[req_id]
56+
# for req_id in req_ids
57+
# if convert_sub_id_to_group_id(req_id) == req_id
58+
# ]
59+
# groups = [
60+
# g_infer_context.group_mapping[req_id] for req_id in req_ids if convert_sub_id_to_group_id(req_id) == req_id
61+
# ]
62+
# kwargs, group_run_reqs = prepare_prefill_inputs(
63+
# group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
64+
# )
65+
# logits = self.model.forward(**kwargs)
66+
# batch_idx, run_reqs = self.diverse_copy(groups)
67+
# logits = logits[batch_idx]
68+
# next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
69+
# next_token_ids = next_token_ids.detach().cpu().numpy()
70+
# next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
6971

70-
self._post_handle(
71-
run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
72-
)
72+
# self._post_handle(
73+
# run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
74+
# )
7375
return
7476

7577
def decode(self):
7678
uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
7779
g_infer_context.infer_req_ids
7880
)
79-
assert len(uninit_reqs) == 0
80-
assert len(prefill_reqs) == 0
8181

8282
if aborted_reqs:
8383
g_infer_context.filter_reqs(aborted_reqs)
8484

85+
if prefill_reqs:
86+
group_reqs = [
87+
g_infer_context.requests_mapping[req.req_id]
88+
for req in prefill_reqs
89+
if convert_sub_id_to_group_id(req.req_id) == req.req_id
90+
]
91+
groups = [
92+
g_infer_context.group_mapping[req.req_id] for req in prefill_reqs if convert_sub_id_to_group_id(req.req_id) == req.req_id
93+
]
94+
kwargs, group_run_reqs = prepare_prefill_inputs(
95+
group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
96+
)
97+
logits = self.model.forward(**kwargs)
98+
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
99+
self.build_group(uninit_reqs)
100+
batch_idx, run_reqs = self.diverse_copy(groups)
101+
logits = logits[batch_idx]
102+
next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
103+
next_token_ids = next_token_ids.detach().cpu().numpy()
104+
next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
105+
106+
self._post_handle(
107+
run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
108+
)
109+
85110
if decode_reqs:
86111
kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
87112
logits = self.model.forward(**kwargs)
88113

89-
self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
114+
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
115+
self.build_group(uninit_reqs)
90116

91117
next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
92118
next_token_ids = next_token_ids.detach().cpu().numpy()
@@ -95,6 +121,8 @@ def decode(self):
95121
self._post_handle(
96122
run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
97123
)
98-
99-
self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
124+
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
125+
self.build_group(uninit_reqs)
126+
uninit_reqs.clear()
127+
ok_finished_reqs.clear()
100128
return

lightllm/server/router/req_queue/continues_batch/beam_impl.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new
4646
# prefill token 计算
4747
for req in cur_handle_group_reqs:
4848
new_batch_first_router_need_tokens += req.shm_cur_output_len
49-
new_batch_first_router_need_tokens += req.input_len
50-
49+
new_batch_first_router_need_tokens += req.get_first_router_need_tokens()
5150
ok_token_num = (
5251
need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
5352
< self.max_total_tokens

0 commit comments

Comments (0)