@@ -196,13 +196,42 @@ def start(self, api_server_pid=None):
196
196
engine_worker_queue_port = self .cfg .engine_worker_queue_port ,
197
197
pid_suffix = self .ipc_signal_suffix ,
198
198
)
199
- self .launched_cache_manager_signal .value [0 ] = 1
200
199
201
200
self .worker_proc = self ._start_worker_service ()
202
201
console_logger .info ("Waiting worker processes ready..." )
203
202
time .sleep (5 )
204
203
self .worker_init_status = dict ()
205
- if not self .check_worker_initialize_status ():
204
+
205
+ result_container = {}
206
+
207
+ def check_worker_initialize_status_func (res : dict ):
208
+ res ["worker_is_alive" ] = True
209
+ if not self .check_worker_initialize_status ():
210
+ console_logger .error ("Failed to launch worker processes, check log/workerlog.* for more details." )
211
+ res ["worker_is_alive" ] = False
212
+
213
+ self .check_worker_initialize_status_func_thread = threading .Thread (
214
+ target = check_worker_initialize_status_func , args = (result_container ,), daemon = True
215
+ )
216
+ self .check_worker_initialize_status_func_thread .start ()
217
+
218
+ # Wait model loading
219
+ while self .loaded_model_signal .value [0 ] == 0 :
220
+ # Make sure worker process is alive
221
+ if not self .check_worker_initialize_status_func_thread .is_alive ():
222
+ return False
223
+ time .sleep (1 )
224
+
225
+ if self .do_profile :
226
+ self ._stop_profile ()
227
+ # Launch components: scheduler, cache_manager, expert_service, etc.
228
+ self .launch_components ()
229
+ if self .cfg .cache_config .enable_prefix_caching or self .cfg .splitwise_role != "mixed" :
230
+ self .launched_cache_manager_signal .value [0 ] = 1
231
+
232
+ # Worker launched
233
+ self .check_worker_initialize_status_func_thread .join ()
234
+ if not result_container ["worker_is_alive" ]:
206
235
console_logger .error ("Failed to launch worker processes, check log/workerlog.* for more details." )
207
236
return False
208
237
@@ -214,68 +243,6 @@ def start(self, api_server_pid=None):
214
243
self ._del_warmup_token_processor ()
215
244
console_logger .info ("Warmup finished" )
216
245
217
- self .token_processor .tasks_queue = self .engine_worker_queue
218
-
219
- if envs .ENABLE_V1_KVCACHE_SCHEDULER :
220
- self .insert_task_to_worker_thread = threading .Thread (target = self ._scheduler_task_to_worker_v1 , daemon = True )
221
- else :
222
- self .insert_task_to_worker_thread = threading .Thread (target = self ._insert_task_to_worker , daemon = True )
223
- self .insert_task_to_worker_thread .start ()
224
-
225
- if self .api_server_pid is not None :
226
- self .insert_task_to_scheduler_thread = threading .Thread (
227
- target = self ._insert_zmq_task_to_scheduler , daemon = True
228
- )
229
- self .insert_task_to_scheduler_thread .start ()
230
-
231
- self .receive_output_thread = threading .Thread (target = self ._zmq_send_generated_tokens , daemon = True )
232
- self .receive_output_thread .start ()
233
-
234
- # Start TokenProcessor thread
235
- self .token_processor .run ()
236
-
237
- if self .cfg .splitwise_role != "mixed" :
238
- # 单机逻辑
239
- self .engine_worker_queue .available_prefill_instances .put (1 )
240
- self .split_mode_get_tasks ()
241
- if self .cfg .scheduler_config .name == "splitwise" :
242
- self .splitwise_receive_thread = threading .Thread (target = self .split_connector .start_receiver , args = ())
243
- self .splitwise_receive_thread .daemon = True
244
- self .splitwise_receive_thread .start ()
245
-
246
- self .cfg .init_cache_info ()
247
-
248
- role = self .cfg .splitwise_role
249
- host_ip = self .cfg .host_ip
250
- disaggregate = self .cfg .disaggregate_info
251
- if self .cfg .scheduler_config .name == "splitwise" :
252
- self .scheduler .start (role , host_ip , disaggregate )
253
-
254
- time .sleep (1 )
255
-
256
- if self .cfg .parallel_config .enable_expert_parallel and self .cfg .parallel_config .data_parallel_size > 1 :
257
- self .dp_processed = []
258
- for i in range (
259
- 1 ,
260
- self .cfg .parallel_config .data_parallel_size // self .cfg .nnode ,
261
- ):
262
- time .sleep (1 )
263
- self .dp_processed .append (
264
- multiprocessing .Process (
265
- target = start_expert_service ,
266
- args = (
267
- self .cfg ,
268
- i + self .cfg .node_rank * self .cfg .worker_num_per_node ,
269
- self .ipc_signal_suffix ,
270
- ),
271
- )
272
- )
273
- llm_logger .info (
274
- f"Engine is initialized successfully with { self .cfg .tensor_parallel_size } "
275
- + f" data parallel id { i } "
276
- )
277
- self .dp_processed [- 1 ].start ()
278
-
279
246
console_logger .info (f"Worker processes are launched with { time .time () - start_time } seconds." )
280
247
return True
281
248
@@ -909,7 +876,7 @@ def _init_worker_signals(self):
909
876
create = True ,
910
877
)
911
878
912
- # exist_task_signal 用于各worker进程感知是否有新Task需要处理
879
+ # exist_task_signal: Used by each worker process to detect whether there is a new task to be processed
913
880
exist_task_signal_data = np .zeros ([self .cfg .parallel_config .data_parallel_size ], dtype = np .int32 )
914
881
self .exist_task_signal = IPCSignal (
915
882
name = "exist_task_signal" ,
@@ -919,7 +886,7 @@ def _init_worker_signals(self):
919
886
create = True ,
920
887
)
921
888
922
- # exist_swapped_task_signal 用于engine感知worker中是否存在swapped task
889
+ # exist_swapped_task_signal: Used by the engine to detect whether there is a swapped task in the worker
923
890
exist_swapped_task_signal_data = np .zeros ([self .cfg .parallel_config .data_parallel_size ], dtype = np .int32 )
924
891
self .exist_swapped_task_signal = IPCSignal (
925
892
name = "exist_swapped_task_signal" ,
@@ -929,7 +896,7 @@ def _init_worker_signals(self):
929
896
create = True ,
930
897
)
931
898
932
- # exist_prefill_task_signal 用于各worker进程感知是否进行prefill
899
+ # exist_prefill_task_signal: Used by each worker process to detect whether to prefill
933
900
exist_prefill_task_signal_data = np .zeros ([1 ], dtype = np .int32 )
934
901
self .exist_prefill_task_signal = IPCSignal (
935
902
name = "exist_prefill_task_signal" ,
@@ -939,7 +906,7 @@ def _init_worker_signals(self):
939
906
create = True ,
940
907
)
941
908
942
- # launched_cache_manager_signal 用于感知engine是否启动了cache_manager
909
+ # launched_cache_manager_signal: Used to detect whether the engine has started cache_manager
943
910
if self .cfg .cache_config .enable_prefix_caching or self .cfg .splitwise_role != "mixed" :
944
911
launched_cache_manager_signal_data = np .zeros ([1 ], dtype = np .int32 )
945
912
self .launched_cache_manager_signal = IPCSignal (
@@ -950,7 +917,30 @@ def _init_worker_signals(self):
950
917
create = True ,
951
918
)
952
919
953
- # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
920
+ # launched_expert_service_signal: Used to detect whether each expert_service is started successfully
921
+ if self .cfg .parallel_config .enable_expert_parallel and self .cfg .parallel_config .data_parallel_size > 1 :
922
+ launched_expert_service_signal_data = np .zeros (
923
+ shape = [self .cfg .parallel_config .data_parallel_size // self .cfg .nnode ], dtype = np .int32
924
+ )
925
+ self .launched_expert_service_signal = IPCSignal (
926
+ name = "launched_expert_service_signal" ,
927
+ array = launched_expert_service_signal_data ,
928
+ dtype = np .int32 ,
929
+ suffix = self .ipc_signal_suffix ,
930
+ create = True ,
931
+ )
932
+
933
+ # loaded_model_signal: Used to detect whether each worker has completed model loading
934
+ loaded_model_signal_data = np .zeros ([1 ], dtype = np .int32 )
935
+ self .loaded_model_signal = IPCSignal (
936
+ name = "loaded_model_signal" ,
937
+ array = loaded_model_signal_data ,
938
+ dtype = np .int32 ,
939
+ suffix = self .ipc_signal_suffix ,
940
+ create = True ,
941
+ )
942
+
943
+ # worker_live_signal: Used by the engine to detect whether each worker process is alive and record the time of each step
954
944
worker_healthy_live_recorded_time_array = np .zeros (shape = [self .cfg .worker_num_per_node ], dtype = np .int32 )
955
945
self .worker_healthy_live_signal = IPCSignal (
956
946
name = "worker_healthy_live_signal" ,
@@ -1187,7 +1177,7 @@ def generate(self, prompts, stream):
1187
1177
llm_logger .error (f"Error happend while adding request, details={ e } " )
1188
1178
raise EngineError (str (e ), error_code = 400 )
1189
1179
1190
- # 获取当前请求的结果
1180
+ # Get the result of the current request
1191
1181
for result in self ._get_generated_tokens (req_id ):
1192
1182
is_end = result .finished
1193
1183
if stream and not is_end :
@@ -1231,7 +1221,6 @@ def _stop_profile(self):
1231
1221
engine_worker_queue_port = self .cfg .engine_worker_queue_port ,
1232
1222
pid_suffix = self .ipc_signal_suffix ,
1233
1223
)
1234
- self .launched_cache_manager_signal .value [0 ] = 1
1235
1224
1236
1225
def check_health (self , time_interval_threashold = 30 ):
1237
1226
"""
@@ -1245,6 +1234,72 @@ def check_health(self, time_interval_threashold=30):
1245
1234
1246
1235
return True , ""
1247
1236
1237
    def launch_components(self):
        """Launch the engine's runtime components after the workers are up.

        Starts the task-insertion and output threads, runs the TokenProcessor,
        performs splitwise-role setup, starts the scheduler (splitwise only),
        and spawns one expert-service process per extra data-parallel rank,
        blocking until each one reports ready via its IPC signal.
        """
        # Route worker tasks through the shared engine worker queue.
        self.token_processor.tasks_queue = self.engine_worker_queue

        # Scheduling thread: v1 KV-cache scheduler or the legacy inserter,
        # selected by environment flag.
        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
            self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
        else:
            self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
        self.insert_task_to_worker_thread.start()

        # ZMQ bridging threads are only needed when serving behind an API server.
        if self.api_server_pid is not None:
            self.insert_task_to_scheduler_thread = threading.Thread(
                target=self._insert_zmq_task_to_scheduler, daemon=True
            )
            self.insert_task_to_scheduler_thread.start()

            self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
            self.receive_output_thread.start()

        # Start TokenProcessor thread
        self.token_processor.run()

        if self.cfg.splitwise_role != "mixed":
            # Single-machine logic (translated from: 单机逻辑)
            self.engine_worker_queue.available_prefill_instances.put(1)
            self.split_mode_get_tasks()
            if self.cfg.scheduler_config.name == "splitwise":
                self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
                self.splitwise_receive_thread.daemon = True
                self.splitwise_receive_thread.start()

        self.cfg.init_cache_info()

        role = self.cfg.splitwise_role
        host_ip = self.cfg.host_ip
        disaggregate = self.cfg.disaggregate_info
        # Only the splitwise scheduler needs an explicit start with topology info.
        if self.cfg.scheduler_config.name == "splitwise":
            self.scheduler.start(role, host_ip, disaggregate)

        time.sleep(1)

        # Number of expert-service instances this node is responsible for.
        expert_service_nums = self.cfg.parallel_config.data_parallel_size // self.cfg.nnode
        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
            self.dp_processed = []
            # Rank 0 is served in-process; spawn a process for ranks 1..N-1.
            for i in range(
                1,
                expert_service_nums,
            ):
                time.sleep(1)
                self.dp_processed.append(
                    multiprocessing.Process(
                        target=start_expert_service,
                        args=(
                            self.cfg,
                            # Global data-parallel rank for this node-local index.
                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
                            self.ipc_signal_suffix,
                        ),
                    )
                )
                llm_logger.info(
                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
                    + f" data parallel id {i}"
                )
                self.dp_processed[-1].start()
            # Block until every spawned expert service flips its IPC
            # launched_expert_service_signal slot to non-zero.
            for i in range(1, expert_service_nums):
                while self.launched_expert_service_signal.value[i] == 0:
                    time.sleep(10)
1248
1303
def check_worker_initialize_status (self ):
1249
1304
"""
1250
1305
Check the initlialize status of workers by stdout logging
@@ -1270,10 +1325,6 @@ def detect_thread():
1270
1325
1271
1326
self .checking_worker_status_thread = threading .Thread (target = detect_thread , daemon = True )
1272
1327
self .checking_worker_status_thread .start ()
1273
- checking_worker_init_kv_cache_status_thread = None
1274
- if self .do_profile :
1275
- checking_worker_init_kv_cache_status_thread = threading .Thread (target = self ._stop_profile , daemon = True )
1276
- checking_worker_init_kv_cache_status_thread .start ()
1277
1328
1278
1329
# display weight loadding progress
1279
1330
with tqdm (total = 100 , desc = "Loading Weights" ) as pbar :
@@ -1304,8 +1355,6 @@ def detect_thread():
1304
1355
self .worker_init_status ["finished" ] = True
1305
1356
try :
1306
1357
self .checking_worker_status_thread .join (timeout = 1 )
1307
- if checking_worker_init_kv_cache_status_thread is not None :
1308
- checking_worker_init_kv_cache_status_thread .join (timeout = 1 )
1309
1358
except Exception :
1310
1359
pass
1311
1360
return True
0 commit comments