 import numpy as np
 from ...batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
+from lightllm.common.basemodel.infer_lock import g_router_lock
 
 
 class ChunkedPrefillQueue(BaseQueue):
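
The only new dependency in this change is `g_router_lock`, imported above and taken as `with g_router_lock.obj:` in the hunks below. A minimal sketch of the shape that holder needs, with stand-in code rather than lightllm's actual infer_lock implementation, assuming a plain multiprocessing lock:

# Sketch only: the real g_router_lock is defined in
# lightllm.common.basemodel.infer_lock. All this diff relies on is an object
# whose .obj attribute is a context-manager lock shared across the router's
# processes.
from multiprocessing import Lock

class _LockHolder:
    def __init__(self):
        self.obj = Lock()  # entered/exited by `with holder.obj:`

g_router_lock = _LockHolder()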
@@ -25,31 +26,32 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
         self.cache_len_list.sort(key=lambda x: -x[1])
 
         left_out_len_array = np.array([e[1] for e in self.cache_len_list])
-        # assert left_out_len_array.min() >= 0
         has_run_len_array = np.array([e[0] for e in self.cache_len_list])
         cum_run_len_array = np.cumsum(has_run_len_array)
         size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
 
         need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
-        ok_token_num = (
-            need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
-            < self.max_total_tokens
-        )
+        with g_router_lock.obj:
+            ok_token_num = (
+                need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
+                < self.max_total_tokens
+            )
 
-        ok_req_num = len(self.cache_len_list) <= self.running_max_req_size
-        new_batch_first_router_need_tokens += req.get_first_router_need_tokens()
-        ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens
+            ok_req_num = len(self.cache_len_list) <= self.running_max_req_size
 
-        if ok_token_num and ok_req_num and ok_prefill:
-            self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index)
-            self.router.shared_token_load.set_dynamic_max_load(
-                (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
-                / self.max_total_tokens,
-                self.dp_index,
-            )
-            return True, new_batch_first_router_need_tokens
-        else:
-            return False, new_batch_first_router_need_tokens
+            new_batch_first_router_need_tokens += req.get_first_router_need_tokens()
+            ok_prefill = new_batch_first_router_need_tokens <= self.batch_max_tokens
+
+            if ok_token_num and ok_req_num and ok_prefill:
+                self.router.shared_token_load.set_estimated_peak_token_count(need_max_token_num, self.dp_index)
+                self.router.shared_token_load.set_dynamic_max_load(
+                    (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
+                    / self.max_total_tokens,
+                    self.dp_index,
+                )
+                return True, new_batch_first_router_need_tokens
+            else:
+                return False, new_batch_first_router_need_tokens
 
     # @calculate_time(show=True, min_cost_ms=10)
     def generate_new_batch(self, current_batch: Batch, limit_router_queue_length: int = None):
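
The hunk above turns the admission test and the load bookkeeping into one critical section: `ok_token_num` reads the frozened token count, and on success `set_dynamic_max_load` publishes a load ratio derived from a second read of that count. Holding `g_router_lock.obj` across both keeps the published ratio consistent with the count that was actually checked. A toy sketch of the check-then-publish pattern, using hypothetical names in place of the shared_token_load API:

# Toy check-then-publish under a single lock (hypothetical names). If the test
# and the write were locked separately, another process could change `frozened`
# in between, and the published load would no longer match what was checked.
from multiprocessing import Lock, Value

lock = Lock()
frozened = Value("i", 0)            # stand-in for the shared frozened token count
dynamic_max_load = Value("d", 0.0)  # stand-in for the published load ratio
MAX_TOTAL_TOKENS = 16384

def try_admit(need_max_token_num: int) -> bool:
    with lock:
        snapshot = frozened.value  # one read, under the lock
        if need_max_token_num + snapshot >= MAX_TOTAL_TOKENS:
            return False
        # publish a ratio consistent with the snapshot just checked
        dynamic_max_load.value = (need_max_token_num + snapshot) / MAX_TOTAL_TOKENS
        return True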
@@ -114,8 +116,10 @@ def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch):
             need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
         else:
             need_max_token_num = 0
-        return (
-            need_max_token_num,
-            (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
-            / self.max_total_tokens,
-        )
+
+        with g_router_lock.obj:
+            return (
+                need_max_token_num,
+                (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
+                / self.max_total_tokens,
+            )
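
`_calcu_batch_token_load_batch_not_none` gets the same treatment: the returned load ratio is now computed from a frozened count read under the lock, i.e. from one consistent snapshot. One caveat: if `g_router_lock.obj` is a plain multiprocessing lock it is not reentrant, so nothing called inside these `with` blocks may try to take it again. A short sketch of the snapshot-under-lock read, reusing the stand-in `lock`, `frozened`, and `MAX_TOTAL_TOKENS` from the sketch above:

def calcu_batch_token_load(need_max_token_num: int) -> tuple:
    # the division sees the same value of the shared count that any concurrent
    # writer is serialized against
    with lock:
        return (
            need_max_token_num,
            (need_max_token_num + frozened.value) / MAX_TOTAL_TOKENS,
        )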