 import os
 import shutil
 import torch
-from .impl import ContinuesBatchBackend
-from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
+from .impl import ChunkedPrefillBackend
 from lightllm.server.core.objs import FinishStatus
-from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, InferSamplingParams
+from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
 from lightllm.server.router.model_infer.mode_backend.generic_pre_process import (
     prepare_prefill_inputs,
     prepare_decode_inputs,
 logger = init_logger(__name__)


-class OutlinesConstraintBackend(ContinuesBatchBackend):
+class OutlinesConstraintBackend(ChunkedPrefillBackend):
     def __init__(self) -> None:
         super().__init__()

@@ -45,63 +44,23 @@ def init_custom(self):
         logger.info(f"eos_ids {self.tokenizer.eos_token_ids}")
         return

-    def prefill(self, reqs: List[Tuple]):
-
-        req_ids = self._init_reqs(reqs)
-
-        # import here: the code can still run without these dependencies when this mode is not used
-        from outlines.fsm.guide import RegexGuide
-
-        req_objs = self._trans_req_ids_to_req_objs(req_ids)
-        kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal)
-        run_reqs: List[InferReq] = run_reqs
-
-        logics = self.model.forward(**kwargs)
-
-        # for logits positions that fail the prefix match, set a large negative value to mask their probability to 0
-        mask = torch.ones_like(logics, dtype=torch.bool)
-        for i, run_obj in enumerate(run_reqs):
-            run_obj: InferReq = run_obj
-            sample_params = run_obj.sampling_param
-            if sample_params.regular_constraint is not None:
-                sample_params.regex_guide = RegexGuide.from_regex(sample_params.regular_constraint, self.tokenizer)
-            self._mask_req_out_token(i, run_obj, mask)
-
-        logics[mask] = -1000000.0
-
-        next_token_ids, next_token_probs = sample(logics, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-        self._post_handle(
-            run_reqs,
-            next_token_ids,
-            next_token_logprobs,
-            is_chuncked_mode=False,
-            do_filter_finished_reqs=False,
-            extra_post_req_handle_func=self._update_state_fsm,
-        )
-
-        return
-
     def decode(self):
         uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
             g_infer_context.infer_req_ids
         )
-        assert len(uninit_reqs) == 0
-        assert len(prefill_reqs) == 0

         if aborted_reqs:
             g_infer_context.filter_reqs(aborted_reqs)

+        # run decode first
         if decode_reqs:
             kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            run_reqs: List[InferReq] = run_reqs
-
             logits = self.model.forward(**kwargs)
+            self._overlap_req_init_and_filter(
+                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
+            )
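+            # CPU-side request init/filter bookkeeping overlaps with the asynchronous GPU forward launched above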

-            self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
-
+            self._init_guide_infos(run_reqs)
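+            # when no request in this batch carries a constraint, skip building the mask entirely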
             all_has_no_constraint = all([not e.sampling_param.has_constraint_setting() for e in run_reqs])
             if not all_has_no_constraint:
                 mask = torch.ones_like(logits, dtype=torch.bool)
@@ -112,7 +71,6 @@ def decode(self):
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
             self._post_handle(
                 run_reqs,
                 next_token_ids,
@@ -121,8 +79,42 @@ def decode(self):
                 do_filter_finished_reqs=False,
                 extra_post_req_handle_func=self._update_state_fsm,
             )
+            logits = None  # drop the reference so the decode logits can be freed before prefill
+
+        # then run prefill
+        if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0):
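+            # decode-priority scheduling: prefill runs only when there is no decode work,
+            # on every max_wait_step-th forward step, or when prefills are explicitly pending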
+            if prefill_reqs:
+                self.need_prefill_count -= 1
+                kwargs, run_reqs = prepare_prefill_inputs(
+                    prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
+                )
+                logits = self.model.forward(**kwargs)
+                self._overlap_req_init_and_filter(
+                    uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
+                )
+                # for logits positions that fail the prefix match, set a large negative value to mask their probability to 0
+                self._init_guide_infos(run_reqs)
+                mask = torch.ones_like(logits, dtype=torch.bool)
+                for i, run_obj in enumerate(run_reqs):
+                    self._mask_req_out_token(i, run_obj, mask)

-        self._overlap_req_init_and_filter(uninit_reqs=[], ok_finished_reqs=ok_finished_reqs, clear_list=True)
+                logits[mask] = -1000000.0
+
+                next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+                next_token_ids = next_token_ids.detach().cpu().numpy()
+                next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+                self._post_handle(
+                    run_reqs,
+                    next_token_ids,
+                    next_token_logprobs,
+                    is_chuncked_mode=True,
+                    do_filter_finished_reqs=False,
+                    extra_post_req_handle_func=self._update_state_fsm,
+                )
+                logits = None
+
+        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+        self.forward_step += 1
         return

     def _update_state_fsm(self, req_obj: InferReq, next_token_id, next_token_logprob):
@@ -138,13 +130,28 @@ def _update_state_fsm(self, req_obj: InferReq, next_token_id, next_token_logprob
     def _mask_req_out_token(self, i, run_obj: InferReq, mask):
         from outlines.fsm.guide import RegexGuide

-        sample_params = run_obj.sampling_param
-        if sample_params.regular_constraint is not None:
-            regex_guide: RegexGuide = sample_params.regex_guide
-            ok_token_id_list = regex_guide.get_next_instruction(sample_params.fsm_current_state).tokens
-            mask[i, ok_token_id_list] = False
-        elif sample_params.allowed_token_ids is not None:
-            mask[i, sample_params.allowed_token_ids] = False
+        if run_obj.get_chuncked_input_token_len() == run_obj.get_cur_total_len():
+            # this run_obj is ready to generate its next token.
+            sample_params = run_obj.sampling_param
+            if sample_params.regular_constraint is not None:
+                regex_guide: RegexGuide = sample_params.regex_guide
+                ok_token_id_list = regex_guide.get_next_instruction(sample_params.fsm_current_state).tokens
+                mask[i, ok_token_id_list] = False
+            elif sample_params.allowed_token_ids is not None:
+                mask[i, sample_params.allowed_token_ids] = False
+            else:
+                mask[i, :] = False
         else:
+            # no constraint to apply for a request still in chunked prefill
             mask[i, :] = False
         return
+
+    def _init_guide_infos(self, run_reqs: List[InferReq]):
+        from outlines.fsm.guide import RegexGuide
+
+        for i, run_obj in enumerate(run_reqs):
+            run_obj: InferReq = run_obj
+            sample_params = run_obj.sampling_param
+            if sample_params.regular_constraint is not None:
+                if not hasattr(sample_params, "regex_guide"):
+                    sample_params.regex_guide = RegexGuide.from_regex(sample_params.regular_constraint, self.tokenizer)
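+                    # the guide is cached on sampling_param, so it is built at most once per request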