@@ -24,21 +24,24 @@ def __init__(self) -> None:
     def init_custom(self):
         pass
 
-    def build_group(self, req_ids: List[int]):
-        for r_id in req_ids:
-            req: InferReq = g_infer_context.requests_mapping[r_id]
+    def build_group(self, reqs: List[InferReq]):
+        for req in reqs:
             group_req_id = req.shm_req.group_req_id
             if group_req_id not in g_infer_context.group_mapping:
                 g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id)
-            g_infer_context.group_mapping[group_req_id].add_req(r_id)
+            g_infer_context.group_mapping[group_req_id].add_req(req.req_id)
 
     def diverse_copy(self, groups: List[InferReqGroup]):
         batch_idx = []
         run_reqs = []
         for i in range(len(groups)):
             req_group = groups[i]
             best_of = req_group.best_of()
-            if best_of > 1:
+            _0_req_obj = req_group.get_req(0)
+            if (
+                best_of > 1 and
+                _0_req_obj.get_chuncked_input_token_len() == _0_req_obj.get_cur_total_len()
+            ):
                 req_group.diverse_copy(g_infer_context.req_manager, is_prefill=True)
                 batch_idx.extend([i for _ in range(best_of)])
             else:
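The behavioral change in this hunk is the new gate in `diverse_copy`: under chunked prefill a group can reach this point once per chunk, so the `best_of` fan-out copy must fire only on the chunk that completes the prompt, i.e. when `get_chuncked_input_token_len()` has caught up with `get_cur_total_len()`. A minimal sketch of that condition, using hypothetical stand-ins (`FakeReq`, `should_diverse_copy`) rather than the repo's classes:

```python
from dataclasses import dataclass


@dataclass
class FakeReq:
    """Illustrative stand-in for InferReq; field names are assumptions."""
    total_len: int          # full prompt length (cf. get_cur_total_len)
    prefilled_len: int = 0  # tokens prefilled so far (cf. get_chuncked_input_token_len)


def should_diverse_copy(first_req: FakeReq, best_of: int) -> bool:
    # Fan the group out only when best_of > 1 AND the prompt is fully
    # prefilled; copying after an intermediate chunk would fork a
    # half-built KV state once per sibling request.
    return best_of > 1 and first_req.prefilled_len == first_req.total_len


req = FakeReq(total_len=1024)
for chunk in (256, 256, 256, 256):
    req.prefilled_len += chunk
    print(req.prefilled_len, should_diverse_copy(req, best_of=4))
# Prints False at 256/512/768; only the final chunk (1024) triggers the copy.
```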
@@ -47,46 +50,69 @@ def diverse_copy(self, groups: List[InferReqGroup]):
         return batch_idx, run_reqs
 
     def prefill(self, reqs: List[Tuple]):
-        req_ids = self._init_reqs(reqs)
-        self.build_group(req_ids)
-        group_reqs = [
-            g_infer_context.requests_mapping[req_id]
-            for req_id in req_ids
-            if convert_sub_id_to_group_id(req_id) == req_id
-        ]
-        groups = [
-            g_infer_context.group_mapping[req_id] for req_id in req_ids if convert_sub_id_to_group_id(req_id) == req_id
-        ]
-        kwargs, group_run_reqs = prepare_prefill_inputs(
-            group_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
-        )
-        logits = self.model.forward(**kwargs)
-        batch_idx, run_reqs = self.diverse_copy(groups)
-        logits = logits[batch_idx]
-        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        req_ids = self._init_reqs(reqs, init_req_obj=False)
+        # group_reqs = [
+        #     g_infer_context.requests_mapping[req_id]
+        #     for req_id in req_ids
+        #     if convert_sub_id_to_group_id(req_id) == req_id
+        # ]
+        # groups = [
+        #     g_infer_context.group_mapping[req_id] for req_id in req_ids if convert_sub_id_to_group_id(req_id) == req_id
+        # ]
+        # kwargs, group_run_reqs = prepare_prefill_inputs(
+        #     group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
+        # )
+        # logits = self.model.forward(**kwargs)
+        # batch_idx, run_reqs = self.diverse_copy(groups)
+        # logits = logits[batch_idx]
+        # next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+        # next_token_ids = next_token_ids.detach().cpu().numpy()
+        # next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
 
-        self._post_handle(
-            run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-        )
+        # self._post_handle(
+        #     run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+        # )
         return
 
     def decode(self):
         uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
             g_infer_context.infer_req_ids
         )
-        assert len(uninit_reqs) == 0
-        assert len(prefill_reqs) == 0
 
         if aborted_reqs:
             g_infer_context.filter_reqs(aborted_reqs)
 
+        if prefill_reqs:
+            group_reqs = [
+                g_infer_context.requests_mapping[req.req_id]
+                for req in prefill_reqs
+                if convert_sub_id_to_group_id(req.req_id) == req.req_id
+            ]
+            groups = [
+                g_infer_context.group_mapping[req.req_id] for req in prefill_reqs if convert_sub_id_to_group_id(req.req_id) == req.req_id
+            ]
+            kwargs, group_run_reqs = prepare_prefill_inputs(
+                group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
+            )
+            logits = self.model.forward(**kwargs)
+            self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
+            self.build_group(uninit_reqs)
+            batch_idx, run_reqs = self.diverse_copy(groups)
+            logits = logits[batch_idx]
+            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+
+            self._post_handle(
+                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+            )
+
         if decode_reqs:
             kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
             logits = self.model.forward(**kwargs)
 
-            self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
+            self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
+            self.build_group(uninit_reqs)
 
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
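With prefill folded into `decode()`, both branches now follow the same overlap pattern: launch `self.model.forward(**kwargs)` first, do the CPU-side bookkeeping (`_overlap_req_init_and_filter`, `build_group`) while the GPU result is still in flight, and only then sample. A rough sketch of the idea, assuming CUDA's usual asynchronous kernel launches (names here are illustrative, not the repo's API):

```python
import torch


def decode_step(model, batch, init_new_reqs, filter_finished):
    # On CUDA this call only enqueues kernels and returns immediately.
    logits = model(batch)
    # CPU housekeeping (init fresh requests, drop finished ones) now runs
    # concurrently with the in-flight forward instead of serially after it.
    init_new_reqs()
    filter_finished()
    # Sampling reads the logits, which forces the synchronization point.
    return torch.argmax(logits, dim=-1)


# Toy usage: a linear "model" over a batch of 4 hidden states.
model = torch.nn.Linear(16, 32000)
batch = torch.randn(4, 16)
print(decode_step(model, batch, lambda: None, lambda: None).shape)  # torch.Size([4])
```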
@@ -95,6 +121,8 @@ def decode(self):
             self._post_handle(
                 run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
             )
-
-        self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
+        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=False)
+        self.build_group(uninit_reqs)
+        uninit_reqs.clear()
+        ok_finished_reqs.clear()
         return
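The epilogue explains the repeated `clear_list=False`: `uninit_reqs` has to stay populated until `build_group(uninit_reqs)` has registered each new sub-request under its `group_req_id`, so the lists are cleared explicitly only after grouping, keeping the next `decode()` tick from reprocessing them. A toy sketch of the bookkeeping `build_group` maintains, with assumed shapes (each sub-request carries the id of its group; `best_of` siblings share one `group_req_id`):

```python
from collections import defaultdict

# group_req_id -> member sub-request ids (sketch of g_infer_context.group_mapping)
group_mapping = defaultdict(list)


def build_group(reqs):
    # Each req is modeled as a (req_id, group_req_id) pair; best_of siblings
    # created by diverse sampling share the same group_req_id.
    for req_id, group_req_id in reqs:
        group_mapping[group_req_id].append(req_id)


# Two best_of=2 groups arriving interleaved still land in the right buckets:
build_group([(10, 10), (20, 20), (11, 10), (21, 20)])
print(dict(group_mapping))  # {10: [10, 11], 20: [20, 21]}
```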