     g_infer_context,
     InferReq,
     InferReqGroup,
-    InferSamplingParams,
 )
 from typing import List, Tuple
-from lightllm.utils.log_utils import init_logger
-from lightllm.server.tokenizer import get_tokenizer
 from lightllm.server.req_id_generator import convert_sub_id_to_group_id
 from lightllm.server.router.model_infer.mode_backend.pre import (
     prepare_prefill_inputs,
-    prepare_decode_inputs,
 )
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
+from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack


 class DiversehBackend(ModeBackend):
     def __init__(self) -> None:
         super().__init__()
-
-    def init_custom(self):
-        pass
-
-    def build_group(self, req_ids: List[int]):
-        for r_id in req_ids:
-            req: InferReq = g_infer_context.requests_mapping[r_id]
-            group_req_id = req.shm_req.group_req_id
-            if group_req_id not in g_infer_context.group_mapping:
-                g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id)
-            g_infer_context.group_mapping[group_req_id].add_req(r_id)
+        self.prefill = self.beam_prefill

     def diverse_copy(self, groups: List[InferReqGroup]):
         batch_idx = []
@@ -46,64 +33,36 @@ def diverse_copy(self, groups: List[InferReqGroup]):
             run_reqs.extend(req_group.get_all_reqs())
         return batch_idx, run_reqs

-    def decode(self):
-        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
-            g_infer_context.infer_req_ids,
-            strict_prefill=True,
+    def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]):
+        group_reqs = [
+            g_infer_context.requests_mapping[req.req_id]
+            for req in prefill_reqs
+            if convert_sub_id_to_group_id(req.req_id) == req.req_id
+        ]
+        groups = [
+            g_infer_context.group_mapping[req.req_id]
+            for req in prefill_reqs
+            if convert_sub_id_to_group_id(req.req_id) == req.req_id
+        ]
+        model_input, group_run_reqs = prepare_prefill_inputs(
+            group_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
+        model_output = self.model.forward(model_input)
+        logits = model_output.logits

-        if aborted_reqs:
-            g_infer_context.filter_reqs(aborted_reqs)
-        if prefill_reqs:
-            group_reqs = [
-                g_infer_context.requests_mapping[req.req_id]
-                for req in prefill_reqs
-                if convert_sub_id_to_group_id(req.req_id) == req.req_id
-            ]
-            groups = [
-                g_infer_context.group_mapping[req.req_id]
-                for req in prefill_reqs
-                if convert_sub_id_to_group_id(req.req_id) == req.req_id
-            ]
-            model_input, group_run_reqs = prepare_prefill_inputs(
-                group_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
-            )
-            model_output = self.model.forward(model_input)
-            logits = model_output.logits
-
-            uninit_req_ids = [req.req_id for req in uninit_reqs]
-            self._overlap_req_init_and_filter(
-                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
-            )
-            self.build_group(uninit_req_ids)
-            batch_idx, run_reqs = self.diverse_copy(groups)
-            logits = logits[batch_idx]
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+        batch_idx, run_reqs = self.diverse_copy(groups)
+        logits = logits[batch_idx]

-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
-            )
+        next_token_ids_gpu, next_token_probs_gpu = sample(logits, run_reqs, self.eos_id)
+        next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy()
+        next_token_logprobs_cpu = torch.log(next_token_probs_gpu).detach().cpu().numpy()

-        if decode_reqs:
-            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
-            model_output = self.model.forward(model_input)
-            logits = model_output.logits
-            uninit_req_ids = [req.req_id for req in uninit_reqs]
-            self._overlap_req_init_and_filter(
-                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
-            )
-            self.build_group(uninit_req_ids)
-
-            next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-            next_token_ids = next_token_ids.detach().cpu().numpy()
-            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
-        uninit_req_ids = [req.req_id for req in uninit_reqs]
-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        self.build_group(uninit_req_ids)
+        update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
+        self._post_handle(
+            run_reqs=run_reqs,
+            next_token_ids=next_token_ids_cpu,
+            next_token_logprobs=next_token_logprobs_cpu,
+            run_reqs_update_packs=update_packs,
+            extra_post_req_handle_func=self.extra_post_req_handle_func,
+        )
         return
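
Note on the __init__ change: the commit deletes the decode() override and the manual build_group() bookkeeping, and instead registers beam_prefill as the backend's prefill callback. A hypothetical sketch of that dispatch pattern follows; it is not lightllm's actual event loop, only an illustration of why rebinding self.prefill is enough:

class ModeBackendSketch:
    def __init__(self):
        # The base loop calls through this attribute, so a subclass can
        # swap in its own handler without overriding the loop itself.
        self.prefill = self.default_prefill

    def default_prefill(self, event_pack, prefill_reqs):
        raise NotImplementedError

class DiverseSketch(ModeBackendSketch):
    def __init__(self):
        super().__init__()
        self.prefill = self.beam_prefill  # same move as in the diff above

    def beam_prefill(self, event_pack, prefill_reqs):
        print("beam prefill for", len(prefill_reqs), "requests")

DiverseSketch().prefill(None, [])  # dispatches to beam_prefill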
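The body of diverse_copy is collapsed by the hunk header above. A minimal sketch of the index-expansion idea it implements, with illustrative names and sizes (diverse_copy_sketch, beams_per_group, and the vocabulary size are assumptions, not lightllm code):

import torch

def diverse_copy_sketch(group_logits, beams_per_group):
    # Prefill computes one logit row per group; replicating row i
    # beams_per_group[i] times gives each beam sub-request its own row,
    # matching the logits = logits[batch_idx] line in beam_prefill.
    batch_idx = []
    for i, n_beams in enumerate(beams_per_group):
        batch_idx.extend([i] * n_beams)
    return group_logits[batch_idx]

# Two groups with 2 and 3 beams expand rows [0, 1] into [0, 0, 1, 1, 1].
expanded = diverse_copy_sketch(torch.randn(2, 32000), [2, 3])
assert expanded.shape == (5, 32000)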
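Both list comprehensions in beam_prefill keep only requests whose sub-request id equals their group id, that is, one representative per beam group. A hedged illustration of the id convention this relies on; the fixed stride below is an assumption about lightllm's req_id_generator, not a verified constant:

MAX_BEST_OF = 8  # assumed stride between group id blocks

def convert_sub_id_to_group_id_sketch(sub_req_id: int) -> int:
    # Sub-requests of one group occupy a contiguous id block whose first
    # id is the group id, so sub_id == group_id holds once per group.
    return (sub_req_id // MAX_BEST_OF) * MAX_BEST_OF

assert convert_sub_id_to_group_id_sketch(16) == 16  # first beam: kept by the filter
assert convert_sub_id_to_group_id_sketch(19) == 16  # later beams: filtered out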
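After the diverse copy, the new code samples one token per beam and converts the sampled probabilities to log-probabilities before leaving the GPU, so _post_handle only deals with numpy arrays. A self-contained sketch of that conversion; plain softmax plus multinomial stands in for lightllm's sample(), which also applies per-request sampling parameters:

import torch

logits = torch.randn(5, 32000)  # one row per beam sub-request, as after diverse_copy
probs = torch.softmax(logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_token_probs = probs[torch.arange(5), next_token_ids]

# Same contract as the diff: ids and logprobs handed on as CPU numpy arrays.
next_token_ids_cpu = next_token_ids.detach().cpu().numpy()
next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy()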