 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 from lightllm.common.basemodel.batch_objs import ModelInput
+from lightllm.utils.envs_utils import get_env_start_args, get_diverse_max_batch_shared_group_size


 def prepare_prefill_inputs(
@@ -99,12 +100,16 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:
     b_mtp_index = []
     b_seq_len = []
     b_q_seq_len = []
+    b_shared_seq_len = []
+    max_batch_shared_group_size = get_diverse_max_batch_shared_group_size()
     for req in req_objs:
+        _radix_shared_len = req.get_radix_cache_shared_len()
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
         seq_len = req.get_cur_total_len()
         assert req.cur_kv_len == seq_len - 1, f"{req.cur_kv_len} {seq_len}"
         b_seq_len.append(seq_len)
+        b_shared_seq_len.append(_radix_shared_len)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
         b_mtp_index.append(0)
@@ -114,6 +119,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:
             b_req_idx.append(req.req_idx)
             seq_len += 1
             b_seq_len.append(seq_len)
+            b_shared_seq_len.append(_radix_shared_len)
             total_token_num += seq_len
             max_len_in_batch = max(max_len_in_batch, seq_len)
             b_mtp_index.append(step + 1)
@@ -124,7 +130,36 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:

     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_shared_seq_len = torch.tensor(b_shared_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    if get_env_start_args().diverse_mode:
+        b_mark_shared_group = []
+        shared_nodes = [req.shared_kv_node for req in run_reqs]
+        _current_group = []
+        for node in shared_nodes:
+            if not _current_group:
+                _current_group.append(node)
+            elif node == _current_group[-1]:
+                _current_group.append(node)
+            else:
+                b_mark_shared_group.extend([0 for _ in range(len(_current_group))])
+                b_mark_shared_group[-1] = len(_current_group)
+                _current_group.clear()
+                _current_group.append(node)
+
+            if len(_current_group) == max_batch_shared_group_size:
+                b_mark_shared_group.extend([0 for _ in range(len(_current_group))])
+                b_mark_shared_group[-1] = len(_current_group)
+                _current_group.clear()
+        if _current_group:
+            b_mark_shared_group.extend([0 for _ in range(len(_current_group))])
+            b_mark_shared_group[-1] = len(_current_group)
+            _current_group.clear()
+
+        assert len(b_mark_shared_group) == len(run_reqs)
+        b_mark_shared_group = torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu")
+    else:
+        b_mark_shared_group = None

     # dynamic prompt cache: prepare tokens
     g_infer_state_lock.acquire()
@@ -144,6 +179,8 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:
         b_req_idx=b_req_idx,
         b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
+        b_shared_seq_len=b_shared_seq_len,
+        b_mark_shared_group=b_mark_shared_group,
         is_prefill=False,
     )
     return model_input, run_reqs
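
For readers following the diverse-mode change above, the sketch below restates the group-marking pass as a standalone helper. It is an illustrative reconstruction, not code from this PR: mark_shared_groups is a hypothetical name, and the toy string nodes stand in for the req.shared_kv_node objects gathered from run_reqs. The logic mirrors the diff: consecutive entries that share the same radix-cache node form a group, a group is flushed early once it reaches max_batch_shared_group_size, and each flushed group writes its length into its last slot while the other slots stay 0, so the marks line up one-to-one with run_reqs (hence the assert in the diff).

from typing import List


def mark_shared_groups(shared_nodes: List[object], max_group_size: int) -> List[int]:
    # One mark per entry: 0 everywhere except the last member of each group,
    # which carries the group's length.
    marks: List[int] = []
    current: List[object] = []

    def _flush() -> None:
        # Pad with zeros for the whole group, then store its size at the last slot.
        marks.extend([0] * len(current))
        marks[-1] = len(current)
        current.clear()

    for node in shared_nodes:
        if current and node != current[-1]:
            _flush()  # the shared node changed: close the previous group
        current.append(node)
        if len(current) == max_group_size:
            _flush()  # group hit the size cap: close it early
    if current:
        _flush()  # close the trailing group
    return marks


# Toy run with a cap of 2: nodes a, a, a, b, b, c split into [a, a], [a], [b, b], [c],
# so the marks come out as [0, 2, 1, 0, 2, 1].
print(mark_shared_groups(["a", "a", "a", "b", "b", "c"], max_group_size=2))

The cap simply bounds how many consecutive requests can be folded into a single shared group per decode batch; requests beyond the cap start a new group even when they share the same node.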