Skip to content

Commit 716efc7

Browse files
committed
reformat
1 parent aa85395 commit 716efc7

File tree

2 files changed

+12
-30
lines changed

2 files changed

+12
-30
lines changed

lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py

Lines changed: 12 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,7 @@ def diverse_copy(self, groups: List[InferReqGroup]):
3838
req_group = groups[i]
3939
best_of = req_group.best_of()
4040
_0_req_obj = req_group.get_req(0)
41-
if (
42-
best_of > 1 and
43-
_0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len()
44-
):
41+
if best_of > 1 and _0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len():
4542
req_group.diverse_copy(g_infer_context.req_manager, is_prefill=True)
4643
batch_idx.extend([i for _ in range(best_of)])
4744
else:
@@ -50,28 +47,7 @@ def diverse_copy(self, groups: List[InferReqGroup]):
5047
return batch_idx, run_reqs
5148

5249
def prefill(self, reqs: List[Tuple]):
53-
req_ids = self._init_reqs(reqs, init_req_obj=False)
54-
# group_reqs = [
55-
# g_infer_context.requests_mapping[req_id]
56-
# for req_id in req_ids
57-
# if convert_sub_id_to_group_id(req_id) == req_id
58-
# ]
59-
# groups = [
60-
# g_infer_context.group_mapping[req_id] for req_id in req_ids if convert_sub_id_to_group_id(req_id) == req_id
61-
# ]
62-
# kwargs, group_run_reqs = prepare_prefill_inputs(
63-
# group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
64-
# )
65-
# logits = self.model.forward(**kwargs)
66-
# batch_idx, run_reqs = self.diverse_copy(groups)
67-
# logits = logits[batch_idx]
68-
# next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
69-
# next_token_ids = next_token_ids.detach().cpu().numpy()
70-
# next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
71-
72-
# self._post_handle(
73-
# run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
74-
# )
50+
self._init_reqs(reqs, init_req_obj=False)
7551
return
7652

7753
def decode(self):
@@ -89,13 +65,17 @@ def decode(self):
8965
if convert_sub_id_to_group_id(req.req_id) == req.req_id
9066
]
9167
groups = [
92-
g_infer_context.group_mapping[req.req_id] for req in prefill_reqs if convert_sub_id_to_group_id(req.req_id) == req.req_id
68+
g_infer_context.group_mapping[req.req_id]
69+
for req in prefill_reqs
70+
if convert_sub_id_to_group_id(req.req_id) == req.req_id
9371
]
9472
kwargs, group_run_reqs = prepare_prefill_inputs(
9573
group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
9674
)
9775
logits = self.model.forward(**kwargs)
98-
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
76+
self._overlap_req_init_and_filter(
77+
uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False
78+
)
9979
self.build_group(uninit_reqs)
10080
batch_idx, run_reqs = self.diverse_copy(groups)
10181
logits = logits[batch_idx]
@@ -111,7 +91,9 @@ def decode(self):
11191
kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
11292
logits = self.model.forward(**kwargs)
11393

114-
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
94+
self._overlap_req_init_and_filter(
95+
uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False
96+
)
11597
self.build_group(uninit_reqs)
11698

11799
next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
@@ -121,6 +103,7 @@ def decode(self):
121103
self._post_handle(
122104
run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
123105
)
106+
124107
self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
125108
self.build_group(uninit_reqs)
126109
uninit_reqs.clear()

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ torchvision==0.20.1
6565
tqdm==4.65.0
6666
transformers==4.51.2
6767
tokenizers==0.21.0
68-
huggingface-hub==0.26.5
6968
triton==3.1.0
7069
urllib3==1.26.16
7170
uvicorn==0.19.0

0 commit comments

Comments (0)