@@ -38,10 +38,7 @@ def diverse_copy(self, groups: List[InferReqGroup]):
             req_group = groups[i]
             best_of = req_group.best_of()
             _0_req_obj = req_group.get_req(0)
-            if (
-                best_of > 1 and
-                _0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len()
-            ):
+            if best_of > 1 and _0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len():
                 req_group.diverse_copy(g_infer_context.req_manager, is_prefill=True)
                 batch_idx.extend([i for _ in range(best_of)])
             else:
@@ -50,28 +47,7 @@ def diverse_copy(self, groups: List[InferReqGroup]):
         return batch_idx, run_reqs

     def prefill(self, reqs: List[Tuple]):
-        req_ids = self._init_reqs(reqs, init_req_obj=False)
-        # group_reqs = [
-        #     g_infer_context.requests_mapping[req_id]
-        #     for req_id in req_ids
-        #     if convert_sub_id_to_group_id(req_id) == req_id
-        # ]
-        # groups = [
-        #     g_infer_context.group_mapping[req_id] for req_id in req_ids if convert_sub_id_to_group_id(req_id) == req_id
-        # ]
-        # kwargs, group_run_reqs = prepare_prefill_inputs(
-        #     group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
-        # )
-        # logits = self.model.forward(**kwargs)
-        # batch_idx, run_reqs = self.diverse_copy(groups)
-        # logits = logits[batch_idx]
-        # next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        # next_token_ids = next_token_ids.detach().cpu().numpy()
-        # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-        # self._post_handle(
-        #     run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
-        # )
+        self._init_reqs(reqs, init_req_obj=False)
         return

     def decode(self):
@@ -89,13 +65,17 @@ def decode(self):
                 if convert_sub_id_to_group_id(req.req_id) == req.req_id
             ]
             groups = [
-                g_infer_context.group_mapping[req.req_id] for req in prefill_reqs if convert_sub_id_to_group_id(req.req_id) == req.req_id
+                g_infer_context.group_mapping[req.req_id]
+                for req in prefill_reqs
+                if convert_sub_id_to_group_id(req.req_id) == req.req_id
             ]
             kwargs, group_run_reqs = prepare_prefill_inputs(
                 group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
             )
             logits = self.model.forward(**kwargs)
-            self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
+            self._overlap_req_init_and_filter(
+                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False
+            )
             self.build_group(uninit_reqs)
             batch_idx, run_reqs = self.diverse_copy(groups)
             logits = logits[batch_idx]
@@ -111,7 +91,9 @@ def decode(self):
             kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
             logits = self.model.forward(**kwargs)

-            self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
+            self._overlap_req_init_and_filter(
+                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False
+            )
             self.build_group(uninit_reqs)

             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
@@ -121,6 +103,7 @@ def decode(self):
             self._post_handle(
                 run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
             )
+
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
         self.build_group(uninit_reqs)
         uninit_reqs.clear()
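
For context on what `logits = logits[batch_idx]` does in the prefill branch above: `diverse_copy` repeats each group's index `best_of` times, so a group's single prefill logits row is replicated once per sampled sub-request before `sample()` runs. Below is a minimal standalone sketch of that replication pattern, using plain tensors and made-up `best_of` values instead of lightllm's `InferReqGroup`/`g_infer_context` machinery; the variable names here are illustrative only.

```python
import torch

# Hypothetical stand-in: one prefill logits row per request group.
group_logits = torch.randn(3, 32000)  # 3 groups, assumed vocab size 32000
best_of = [1, 3, 2]                   # assumed samples requested per group

# Build batch_idx the same way diverse_copy does: repeat each group's
# index best_of times so every sampled sub-request gets its own logits row.
batch_idx = []
for i, n in enumerate(best_of):
    batch_idx.extend([i for _ in range(n)])

expanded = group_logits[batch_idx]    # shape: (sum(best_of), vocab)
assert expanded.shape[0] == sum(best_of)

# Each replicated row can now be sampled independently, e.g. one
# multinomial draw per sub-request.
probs = torch.softmax(expanded, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
print(next_token_ids.shape)  # torch.Size([6])
```

The point of the indexing trick is that only the logits tensor is expanded; the model forward still runs once per group, and the duplicated rows feed the per-sample sampling that follows.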