 import traceback
 import uuid
 import weakref
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple

@@ -125,9 +126,17 @@ def __init__(self, cfg):
             cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
         )

-        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
-
-        self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
+            self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
+        )
+        self.splitwise_queue = deque()
+        self.split_connector = SplitwiseConnector(
+            cfg,
+            self.scheduler,
+            self.engine_worker_queue,
+            self.resource_manager,
+            self.splitwise_queue,
+        )

         self.token_processor = TokenProcessor(
             cfg=self.cfg,
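Review note: the new `splitwise_queue` is a plain `collections.deque` shared between the connector and the engine's receiver loop. A minimal sketch of the handoff this wires up (names are illustrative; how SplitwiseConnector enqueues internally is not shown in this diff):

    from collections import deque

    splitwise_queue = deque()

    # Producer side (connector thread): hand a (role, tasks) batch to the engine.
    # Assumption: the connector pushes on the left so the consumer's pop() below
    # drains in arrival order; the actual connector code is not in this diff.
    def publish(role, tasks):
        splitwise_queue.appendleft((role, tasks))

    # Consumer side (receiver loop): take one batch per iteration, if any.
    def drain_once():
        if splitwise_queue:
            return splitwise_queue.pop()  # single append/pop is atomic in CPython
        return None

deque's append and pop are atomic under the GIL, which is presumably why no explicit lock accompanies the shared queue here.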
@@ -343,12 +352,6 @@ def _insert_task_to_worker(self):
                 if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
                     time.sleep(0.005)
                     continue
-                if self.engine_worker_queue.num_cache_infos() > 0:
-                    time.sleep(0.001)
-                    continue
-                if len(self.split_connector.current_request_ids) > 0:
-                    time.sleep(0.001)
-                    continue

                 num_prefill_batch = min(
                     int(self.resource_manager.available_batch()),
@@ -596,43 +599,42 @@ def receiver_loop():
                     for idx in sorted(processed_indices, reverse=True):
                         self.waiting_requests.pop(idx)

-                if not self.engine_worker_queue.disaggregate_queue_empty():
-                    items = self.engine_worker_queue.get_disaggregated_tasks()
-                    for item in items:
-                        role = item[0]
-                        tasks = item[1]
+                if len(self.splitwise_queue) > 0:
+                    items = self.splitwise_queue.pop()
+                    role = items[0]
+                    tasks = items[1]

-                        if role == "prefill":
-                            for task in tasks:
-                                task.max_tokens = task.min_tokens = 2
-                            self.insert_tasks(tasks)
+                    if role == "prefill":
+                        for task in tasks:
+                            task.max_tokens = task.min_tokens = 2
+                        self.insert_tasks(tasks)

-                        elif role == "decode":
-                            if hasattr(tasks[0], "finished"):
-                                if not isinstance(tasks, list):
-                                    tasks = [tasks]
-                                for task in tasks:
-                                    task.finished = False
-                                self.insert_tasks(tasks, allocated=True)
+                    elif role == "decode":
+                        if hasattr(tasks[0], "finished"):
+                            if not isinstance(tasks, list):
+                                tasks = [tasks]
+                            for task in tasks:
+                                task.finished = False
+                            self.insert_tasks(tasks, allocated=True)

-                                if self.cfg.innode_prefill_ports is not None:
-                                    self.scheduler.put_results(tasks)
+                            if self.cfg.innode_prefill_ports is not None:
+                                self.scheduler.put_results(tasks)

+                        else:
+                            if len(self.waiting_requests):
+                                llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                                self.waiting_requests.extend(tasks)
                             else:
-                                if len(self.waiting_requests):
-                                    llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
-                                    self.waiting_requests.extend(tasks)
-                                else:
-                                    new_waiting = []
-                                    for task in tasks:
-                                        if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                            self.insert_tasks([task])
-                                        else:
-                                            new_waiting.append(task)
-
-                                    if new_waiting:
-                                        self.waiting_requests.extend(new_waiting)
-                                        llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                                new_waiting = []
+                                for task in tasks:
+                                    if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                                        self.insert_tasks([task])
+                                    else:
+                                        new_waiting.append(task)
+
+                                if new_waiting:
+                                    self.waiting_requests.extend(new_waiting)
+                                    llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")

                 else:
                     time.sleep(0.001)
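One behavioral detail worth flagging: the loop now handles one batch per iteration instead of draining everything, and `deque.pop()` removes from the right end, so if the producer enqueues with `append()` (also the right end) batches are consumed newest-first. A quick illustration of the ordering difference (the producer side is an assumption, not shown in this diff):

    from collections import deque

    q = deque()
    q.append(("prefill", ["task-1"]))  # arrives first
    q.append(("decode", ["task-2"]))   # arrives second

    q.pop()      # ("decode", ["task-2"])  -> LIFO relative to append()
    q.popleft()  # ("prefill", ["task-1"]) -> FIFO; appendleft()+pop() is equivalent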
@@ -842,7 +844,6 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

-        self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
@@ -854,6 +855,8 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
         self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
         if is_prefill and self.cfg.scheduler_config.name != "splitwise":
             self.engine_worker_queue.available_prefill_instances.put(1)
+
+        self.split_connector.send_cache_infos(tasks, current_id)
         return True

     def task_is_finished(self, index):
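The net effect of these two hunks is an ordering change: cache infos now go to the splitwise peer only after the tasks have been pushed onto the worker queue. A hedged sketch of the resulting flow (simplified, not the full method):

    # Simplified ordering after this change (illustrative only):
    def insert_tasks_sketch(self, tasks, current_id=-1):
        self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
        # Moved below put_tasks: the peer can no longer react to cache infos
        # for tasks the local workers have not yet been handed.
        self.split_connector.send_cache_infos(tasks, current_id)
        return True

Presumably this reordering is also what lets the `num_cache_infos()` / `current_request_ids` busy-wait guards be dropped from `_insert_task_to_worker` above.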
@@ -1017,13 +1020,16 @@ def _exit_sub_services(self):
         except Exception as e:
             print(f"Error extracting sub services: {e}")

-        self.engine_worker_queue.cleanup()
+
+        for worker_queue in self.engine_worker_queue_server:
+            worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
             self.send_response_server.close()
         if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
             self.recv_request_server.close()
         if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
             self.recv_control_cmd_server.close()
+
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
@@ -1325,15 +1331,20 @@ def start_queue_service(self):
         """
         start queue service for engine worker communication
         """
-        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
+
+        self.engine_worker_queue_server = list()
         if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
-            llm_logger.info(f"Starting engine worker queue server service at {address}")
-            self.engine_worker_queue_server = EngineWorkerQueue(
-                address=address,
-                is_server=True,
-                num_client=self.cfg.tensor_parallel_size,
-                local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
-            )
+            for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
+                address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
+                llm_logger.info(f"Starting engine worker queue service at {address}")
+                self.engine_worker_queue_server.append(
+                    EngineWorkerQueue(
+                        address=address,
+                        is_server=True,
+                        num_client=self.cfg.tensor_parallel_size,
+                        local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                    )
+                )

         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             self.cache_task_queue = EngineCacheQueue(
@@ -1348,6 +1359,7 @@ def start_queue_service(self):
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )

+        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
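Taken together with the `__init__` change, queue addressing now works per data-parallel rank. A small worked example of the port layout (values are made up; names mirror the config fields):

    base_port = 8002                # cfg.engine_worker_queue_port
    data_parallel_size = 4          # cfg.parallel_config.data_parallel_size
    nnode = 2

    # Server side: each node hosts dp_size // nnode queues on consecutive ports.
    server_ports = [base_port + i for i in range(data_parallel_size // nnode)]
    # -> [8002, 8003]

    # Client side: each engine keys its message queue by its local DP rank,
    # matching os.environ["INFERENCE_MSG_QUEUE_ID"] set in __init__.
    local_data_parallel_id = 1
    msg_queue_id = base_port + local_data_parallel_id  # 8003

Note that the client `EngineWorkerQueue` at the bottom of `start_queue_service` still connects to the base port without a rank offset; whether the per-rank offset is applied elsewhere is not visible in this diff.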