1+ import random
12from typing import List , Union
23from lightllm .server .router .req_queue .base_queue import BaseQueue
34from lightllm .server .router .batch import Batch , Req
@@ -21,37 +22,22 @@ def assign_reqs_to_dp(self, current_batch: Batch, reqs_waiting_for_dp_index: Lis
2122 if len (reqs_waiting_for_dp_index ) == 0 :
2223 return
2324 # calculate the total load of each dp rank
25+ all_dp_req_num = [0 for _ in range (self .dp_size_in_node )]
2426 if current_batch is not None :
2527 all_dp_req_num = current_batch .get_all_dp_req_num ()
26- total_load_per_dp = [
27- all_dp_req_num [i ] + len (self .inner_queues [i ].waiting_req_list ) for i in range (self .dp_size_in_node )
28- ]
29- else :
30- total_load_per_dp = [len (self .inner_queues [i ].waiting_req_list ) for i in range (self .dp_size_in_node )]
28+ total_load_per_dp = [
29+ all_dp_req_num [i ] + len (self .inner_queues [i ].waiting_req_list ) for i in range (self .dp_size_in_node )
30+ ]
3131 for req_group in reqs_waiting_for_dp_index :
3232 # calculate the length of this request group
33- if isinstance (req_group , list ):
34- req_length = len (req_group )
35- else :
36- req_length = 1
33+ req_length = len (req_group )
3734
3835 # find the dp rank with minimum load
3936 min_load = min (total_load_per_dp )
4037 select_dp_indexes = [i for i in range (self .dp_size_in_node ) if total_load_per_dp [i ] == min_load ]
41-
42- # select the dp rank with the minimum load
43- if len (select_dp_indexes ) == 1 :
44- suggested_dp_index = select_dp_indexes [0 ]
45- else :
46- # if multiple dp ranks have the same minimum load, randomly select one
47- import random
48-
49- suggested_dp_index = random .choice (select_dp_indexes )
38+ suggested_dp_index = random .choice (select_dp_indexes )
5039
5140 # assign the request to the dp rank and update the load count
52- if not isinstance (req_group , list ):
53- req_group = [req_group ]
54-
5541 for req in req_group :
5642 req .sample_params .suggested_dp_index = suggested_dp_index
5743 self .inner_queues [suggested_dp_index ].append (req )
0 commit comments