
Commit b23af29

zeroRains, gongshaotian, and DDDivano authored
Launch expert_service before kv_cache initialization in worker_process (#3045)
* launch expert_service before kv_cache initialization
* add two signals to make sure model loading and expert_service launching have finished
* fix the EP bug
* fix ep
* update launching way
* fix ep
* update
* roll back ep
* pre-commit all files

---------

Co-authored-by: RAM <[email protected]>
Co-authored-by: Divano <[email protected]>
1 parent c27a3dc commit b23af29

6 files changed: +175 −100
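In outline, the change sequences startup with two shared-memory flags: each worker raises loaded_model_signal once weights are loaded, the engine then launches its remaining components (scheduler, cache manager, expert services), and each expert_service process raises its per-rank slot in launched_expert_service_signal so the engine can wait for all of them. Below is a minimal sketch of that two-flag handshake, using multiprocessing.Value/Array in place of FastDeploy's IPCSignal (an assumption; the real class is shared-memory backed and its internals are not shown in this diff):

import multiprocessing
import time


def worker(loaded_model_signal):
    time.sleep(2)                              # stand-in for model loading
    loaded_model_signal.value = 1              # worker: "model is loaded"
    # kv_cache initialization would continue here, after the engine has had
    # a chance to launch expert services.


def expert_service(launched_expert_service_signal, rank):
    time.sleep(1)                              # stand-in for service startup
    launched_expert_service_signal[rank] = 1   # service: "this rank is ready"


if __name__ == "__main__":
    loaded_model_signal = multiprocessing.Value("i", 0)
    launched_expert_service_signal = multiprocessing.Array("i", 2)

    w = multiprocessing.Process(target=worker, args=(loaded_model_signal,))
    w.start()

    # Engine side: wait for model loading before launching the rest.
    while loaded_model_signal.value == 0:
        time.sleep(0.1)

    services = [
        multiprocessing.Process(target=expert_service,
                                args=(launched_expert_service_signal, r))
        for r in range(2)
    ]
    for s in services:
        s.start()

    # Block until every expert service has flipped its per-rank flag.
    for r in range(2):
        while launched_expert_service_signal[r] == 0:
            time.sleep(0.1)
    print("all expert services ready")
    w.join()
    for s in services:
        s.join()

Launching expert services before kv_cache initialization is exactly the reordering named in the commit title.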

fastdeploy/engine/engine.py

Lines changed: 126 additions & 77 deletions
@@ -196,13 +196,42 @@ def start(self, api_server_pid=None):
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=self.ipc_signal_suffix,
         )
-        self.launched_cache_manager_signal.value[0] = 1

         self.worker_proc = self._start_worker_service()
         console_logger.info("Waiting worker processes ready...")
         time.sleep(5)
         self.worker_init_status = dict()
-        if not self.check_worker_initialize_status():
+
+        result_container = {}
+
+        def check_worker_initialize_status_func(res: dict):
+            res["worker_is_alive"] = True
+            if not self.check_worker_initialize_status():
+                console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
+                res["worker_is_alive"] = False
+
+        self.check_worker_initialize_status_func_thread = threading.Thread(
+            target=check_worker_initialize_status_func, args=(result_container,), daemon=True
+        )
+        self.check_worker_initialize_status_func_thread.start()
+
+        # Wait for model loading to finish
+        while self.loaded_model_signal.value[0] == 0:
+            # Make sure the worker process is alive
+            if not self.check_worker_initialize_status_func_thread.is_alive():
+                return False
+            time.sleep(1)
+
+        if self.do_profile:
+            self._stop_profile()
+        # Launch components: scheduler, cache_manager, expert_service, etc.
+        self.launch_components()
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            self.launched_cache_manager_signal.value[0] = 1
+
+        # Worker launched
+        self.check_worker_initialize_status_func_thread.join()
+        if not result_container["worker_is_alive"]:
             console_logger.error("Failed to launch worker processes, check log/workerlog.* for more details.")
             return False

@@ -214,68 +243,6 @@ def start(self, api_server_pid=None):
             self._del_warmup_token_processor()
             console_logger.info("Warmup finished")

-        self.token_processor.tasks_queue = self.engine_worker_queue
-
-        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
-        else:
-            self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
-        self.insert_task_to_worker_thread.start()
-
-        if self.api_server_pid is not None:
-            self.insert_task_to_scheduler_thread = threading.Thread(
-                target=self._insert_zmq_task_to_scheduler, daemon=True
-            )
-            self.insert_task_to_scheduler_thread.start()
-
-            self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
-            self.receive_output_thread.start()
-
-        # Start TokenProcessor thread
-        self.token_processor.run()
-
-        if self.cfg.splitwise_role != "mixed":
-            # Single-machine logic
-            self.engine_worker_queue.available_prefill_instances.put(1)
-            self.split_mode_get_tasks()
-            if self.cfg.scheduler_config.name == "splitwise":
-                self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
-                self.splitwise_receive_thread.daemon = True
-                self.splitwise_receive_thread.start()
-
-        self.cfg.init_cache_info()
-
-        role = self.cfg.splitwise_role
-        host_ip = self.cfg.host_ip
-        disaggregate = self.cfg.disaggregate_info
-        if self.cfg.scheduler_config.name == "splitwise":
-            self.scheduler.start(role, host_ip, disaggregate)
-
-        time.sleep(1)
-
-        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
-            self.dp_processed = []
-            for i in range(
-                1,
-                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
-            ):
-                time.sleep(1)
-                self.dp_processed.append(
-                    multiprocessing.Process(
-                        target=start_expert_service,
-                        args=(
-                            self.cfg,
-                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
-                            self.ipc_signal_suffix,
-                        ),
-                    )
-                )
-                llm_logger.info(
-                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
-                    + f" data parallel id {i}"
-                )
-                self.dp_processed[-1].start()
-
         console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
         return True

@@ -909,7 +876,7 @@ def _init_worker_signals(self):
             create=True,
         )

-        # exist_task_signal 用于各worker进程感知是否有新Task需要处理
+        # exist_task_signal: Used by each worker process to detect whether there is a new task to be processed
         exist_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
             name="exist_task_signal",

@@ -919,7 +886,7 @@ def _init_worker_signals(self):
             create=True,
         )

-        # exist_swapped_task_signal 用于engine感知worker中是否存在swapped task
+        # exist_swapped_task_signal: Used by the engine to detect whether there is a swapped task in the worker
         exist_swapped_task_signal_data = np.zeros([self.cfg.parallel_config.data_parallel_size], dtype=np.int32)
         self.exist_swapped_task_signal = IPCSignal(
             name="exist_swapped_task_signal",

@@ -929,7 +896,7 @@ def _init_worker_signals(self):
             create=True,
         )

-        # exist_prefill_task_signal 用于各worker进程感知是否进行prefill
+        # exist_prefill_task_signal: Used by each worker process to detect whether to prefill
         exist_prefill_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_prefill_task_signal = IPCSignal(
             name="exist_prefill_task_signal",

@@ -939,7 +906,7 @@ def _init_worker_signals(self):
             create=True,
         )

-        # launched_cache_manager_signal 用于感知engine是否启动了cache_manager
+        # launched_cache_manager_signal: Used to detect whether the engine has started cache_manager
         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
@@ -950,7 +917,30 @@ def _init_worker_signals(self):
             create=True,
         )

-        # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间
+        # launched_expert_service_signal: Used to detect whether each expert_service has started successfully
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            launched_expert_service_signal_data = np.zeros(
+                shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
+            )
+            self.launched_expert_service_signal = IPCSignal(
+                name="launched_expert_service_signal",
+                array=launched_expert_service_signal_data,
+                dtype=np.int32,
+                suffix=self.ipc_signal_suffix,
+                create=True,
+            )
+
+        # loaded_model_signal: Used to detect whether each worker has completed model loading
+        loaded_model_signal_data = np.zeros([1], dtype=np.int32)
+        self.loaded_model_signal = IPCSignal(
+            name="loaded_model_signal",
+            array=loaded_model_signal_data,
+            dtype=np.int32,
+            suffix=self.ipc_signal_suffix,
+            create=True,
+        )
+
+        # worker_live_signal: Used by the engine to detect whether each worker process is alive and record the time of each step
         worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
             name="worker_healthy_live_signal",
@@ -1187,7 +1177,7 @@ def generate(self, prompts, stream):
             llm_logger.error(f"Error happend while adding request, details={e}")
             raise EngineError(str(e), error_code=400)

-        # 获取当前请求的结果
+        # Get the result of the current request
         for result in self._get_generated_tokens(req_id):
             is_end = result.finished
             if stream and not is_end:
@@ -1231,7 +1221,6 @@ def _stop_profile(self):
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=self.ipc_signal_suffix,
         )
-        self.launched_cache_manager_signal.value[0] = 1

     def check_health(self, time_interval_threashold=30):
         """
@@ -1245,6 +1234,72 @@ def check_health(self, time_interval_threashold=30):

         return True, ""

+    def launch_components(self):
+        self.token_processor.tasks_queue = self.engine_worker_queue
+
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.insert_task_to_worker_thread = threading.Thread(target=self._scheduler_task_to_worker_v1, daemon=True)
+        else:
+            self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, daemon=True)
+        self.insert_task_to_worker_thread.start()
+
+        if self.api_server_pid is not None:
+            self.insert_task_to_scheduler_thread = threading.Thread(
+                target=self._insert_zmq_task_to_scheduler, daemon=True
+            )
+            self.insert_task_to_scheduler_thread.start()
+
+            self.receive_output_thread = threading.Thread(target=self._zmq_send_generated_tokens, daemon=True)
+            self.receive_output_thread.start()
+
+        # Start TokenProcessor thread
+        self.token_processor.run()
+
+        if self.cfg.splitwise_role != "mixed":
+            # Single-machine logic
+            self.engine_worker_queue.available_prefill_instances.put(1)
+            self.split_mode_get_tasks()
+            if self.cfg.scheduler_config.name == "splitwise":
+                self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
+                self.splitwise_receive_thread.daemon = True
+                self.splitwise_receive_thread.start()
+
+        self.cfg.init_cache_info()
+
+        role = self.cfg.splitwise_role
+        host_ip = self.cfg.host_ip
+        disaggregate = self.cfg.disaggregate_info
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)
+
+        time.sleep(1)
+        expert_service_nums = self.cfg.parallel_config.data_parallel_size // self.cfg.nnode
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.dp_processed = []
+            for i in range(
+                1,
+                expert_service_nums,
+            ):
+                time.sleep(1)
+                self.dp_processed.append(
+                    multiprocessing.Process(
+                        target=start_expert_service,
+                        args=(
+                            self.cfg,
+                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
+                            self.ipc_signal_suffix,
+                        ),
+                    )
+                )
+                llm_logger.info(
+                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                    + f" data parallel id {i}"
+                )
+                self.dp_processed[-1].start()
+            for i in range(1, expert_service_nums):
+                while self.launched_expert_service_signal.value[i] == 0:
+                    time.sleep(10)
+
     def check_worker_initialize_status(self):
         """
         Check the initlialize status of workers by stdout logging
@@ -1270,10 +1325,6 @@ def detect_thread():

         self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
         self.checking_worker_status_thread.start()
-        checking_worker_init_kv_cache_status_thread = None
-        if self.do_profile:
-            checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True)
-            checking_worker_init_kv_cache_status_thread.start()

         # display weight loadding progress
         with tqdm(total=100, desc="Loading Weights") as pbar:
@@ -1304,8 +1355,6 @@ def detect_thread():
         self.worker_init_status["finished"] = True
         try:
             self.checking_worker_status_thread.join(timeout=1)
-            if checking_worker_init_kv_cache_status_thread is not None:
-                checking_worker_init_kv_cache_status_thread.join(timeout=1)
         except Exception:
             pass
         return True
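The rewritten start() combines a background readiness check with signal polling: the blocking log-tail check runs in a daemon thread that records its verdict in a shared dict, while the main thread polls loaded_model_signal and bails out early if the checker dies reporting failure. A condensed sketch of that pattern, with illustrative names (wait_until_ready, check_fn, ready_flag are not part of FastDeploy):

import threading
import time


def wait_until_ready(check_fn, ready_flag, poll_interval=1.0):
    """Run a blocking check in a daemon thread while polling a readiness flag."""
    result = {}

    def runner(res):
        res["alive"] = check_fn()  # blocks until workers finish (or fail)

    checker = threading.Thread(target=runner, args=(result,), daemon=True)
    checker.start()

    # Poll the flag; abort early only if the checker died reporting failure.
    while not ready_flag.is_set():
        if not checker.is_alive() and not result.get("alive", False):
            return False
        time.sleep(poll_interval)

    checker.join()
    return result.get("alive", False)


if __name__ == "__main__":
    ready = threading.Event()
    threading.Timer(2.0, ready.set).start()  # something flips the flag later
    print(wait_until_ready(lambda: True, ready, poll_interval=0.1))

The result dict plays the same role as result_container in the diff: a daemon thread cannot return a value, so it writes its verdict somewhere the launching thread can read after join().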

fastdeploy/engine/expert_service.py

Lines changed: 18 additions & 5 deletions
@@ -26,7 +26,7 @@
 import numpy as np

 from fastdeploy.engine.resource_manager import ResourceManager
-from fastdeploy.inter_communicator import EngineWorkerQueue
+from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.output.token_processor import TokenProcessor
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
@@ -127,7 +127,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id):
             cache_config=self.cfg.cache_config,
             tensor_parallel_size=self.cfg.tensor_parallel_size,
             device_ids=self.cfg.local_device_ids,
-            pod_ip=self.cfg.pod_ips[0],
+            pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
         )
@@ -141,16 +141,29 @@ def start(self, ipc_signal_suffix, local_data_parallel_id):
         os.environ["INFERENCE_MSG_QUEUE_ID"] = str(local_data_parallel_id + int(self.cfg.engine_worker_queue_port))

         self.token_processor.run()
-
         self.cfg.init_cache_info()
-
         role = self.cfg.splitwise_role
         host_ip = self.cfg.host_ip
         disaggregate = self.cfg.disaggregate_info
         self.scheduler.start(role, host_ip, disaggregate)
         self.cfg.print()

-        console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
+        launched_expert_service_signal_data = np.zeros(
+            shape=[self.cfg.parallel_config.data_parallel_size // self.cfg.nnode], dtype=np.int32
+        )
+        self.launched_expert_service_signal = IPCSignal(
+            name="launched_expert_service_signal",
+            array=launched_expert_service_signal_data,
+            dtype=np.int32,
+            suffix=ipc_signal_suffix,
+            create=False,
+        )
+        local_rank = local_data_parallel_id % self.cfg.worker_num_per_node
+        self.launched_expert_service_signal.value[local_rank] = 1
+
+        console_logger.info(
+            f"Worker processes (rank {local_rank}) are launched with {time.time() - start_time} seconds."
+        )
         return True

     def _insert_task_to_worker(self):
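Both files construct the same IPCSignal by name: the engine allocates it with create=True in _init_worker_signals, and each expert_service attaches with create=False and flips its per-rank slot. A toy stand-in for that rendezvous built on the standard library (this is a guess at the semantics, a named cross-process numpy array, not FastDeploy's implementation, whose internals this diff never shows):

import numpy as np
from multiprocessing import shared_memory


class ToySignal:
    """Named cross-process flag array. create=True allocates the segment;
    create=False attaches to one an earlier process allocated."""

    def __init__(self, name, array, dtype, suffix, create):
        full_name = f"{name}_{suffix}"
        if create:
            self._shm = shared_memory.SharedMemory(
                name=full_name, create=True, size=array.nbytes
            )
            self.value = np.ndarray(array.shape, dtype=dtype, buffer=self._shm.buf)
            self.value[:] = array  # initialize, typically zeros
        else:
            self._shm = shared_memory.SharedMemory(name=full_name)
            self.value = np.ndarray(array.shape, dtype=dtype, buffer=self._shm.buf)

    def close(self, unlink=False):
        self._shm.close()
        if unlink:
            self._shm.unlink()


# Engine side, mirroring _init_worker_signals:
#   flags = ToySignal("launched_expert_service_signal",
#                     np.zeros(4, np.int32), np.int32, "1234", create=True)
# Expert-service side, mirroring expert_service.start():
#   flags = ToySignal("launched_expert_service_signal",
#                     np.zeros(4, np.int32), np.int32, "1234", create=False)
#   flags.value[local_rank] = 1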

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 11 additions & 5 deletions
@@ -97,13 +97,13 @@ def _prepare_decode_task(self, request):

     def _prepare_preempt_task(self, request):
         return ScheduledPreemptTask(idx=request.idx, request_id=request.request_id)
-
+
     def reschedule_preempt_task(self, request_id):
         with self.lock:
             if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
                 request = self.requests[request_id]
                 self.waiting.appendleft(request)
-                self.to_be_rescheduled_request_id_set.remove(request_id)
+                self.to_be_rescheduled_request_id_set.remove(request_id)

     def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_reqs):
         can_schedule = True

(The paired −/+ lines above differ only in whitespace; the commit message notes that pre-commit was run over all files.)
@@ -422,9 +422,15 @@ def finish_requests(self, request_ids: Union[str, Iterable[str]]):
                     self.running.remove(request)
                     request.status = RequestStatus.FINISHED
                     self._free_blocks(request)
-                if request.request_id in self.to_be_rescheduled_request_id_set:  # finished after preempted, blocks have been recycled.
-                    self.to_be_rescheduled_request_id_set.remove(request.request_id)  # just remove from to_be_rescheduled_request_id_set
-                if request in self.waiting:  # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
+                if (
+                    request.request_id in self.to_be_rescheduled_request_id_set
+                ):  # finished after preempted, blocks have been recycled.
+                    self.to_be_rescheduled_request_id_set.remove(
+                        request.request_id
+                    )  # just remove from to_be_rescheduled_request_id_set
+                if (
+                    request in self.waiting
+                ):  # after finished, this request still scheduled from preempted to waiting, unexpected error, should not be here
                     raise RuntimeError(f"request {request.request_id} scheduled into waiting list, after finished")

                 self.tasks_list[request.idx] = None
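The invariant these hunks guard: a preempted request parks its id in to_be_rescheduled_request_id_set; reschedule_preempt_task moves it back to the head of the waiting deque at most once; and finish_requests must unpark ids that finished while parked so they are never rescheduled afterwards. A self-contained sketch of that bookkeeping (ToyRequest, ToyResourceManager, preempt, and finish are illustrative names, not FastDeploy's API):

import threading
from collections import deque
from dataclasses import dataclass


@dataclass
class ToyRequest:
    request_id: str
    idx: int  # slot index in the decode batch


class ToyResourceManager:
    def __init__(self):
        self.lock = threading.Lock()
        self.waiting = deque()
        self.requests = {}  # request_id -> ToyRequest
        self.to_be_rescheduled_request_id_set = set()

    def preempt(self, request_id):
        # Preemption parks the id until the scheduler can retry it.
        with self.lock:
            self.to_be_rescheduled_request_id_set.add(request_id)

    def reschedule_preempt_task(self, request_id):
        with self.lock:
            # Only reschedule ids that are both parked and still live.
            if (request_id in self.to_be_rescheduled_request_id_set
                    and request_id in self.requests):
                self.waiting.appendleft(self.requests[request_id])
                self.to_be_rescheduled_request_id_set.remove(request_id)

    def finish(self, request_id):
        with self.lock:
            # Finished while parked: unpark so it is never rescheduled;
            # its blocks were already recycled at preemption time.
            self.to_be_rescheduled_request_id_set.discard(request_id)
            self.requests.pop(request_id, None)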

fastdeploy/output/token_processor.py

Lines changed: 4 additions & 2 deletions
@@ -296,12 +296,14 @@ def _process_batch_output(self):
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
-
+
         batch_result = list()
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             need_to_be_reschedule_req_ids = list(self.resource_manager.to_be_rescheduled_request_id_set)
             for request_id in need_to_be_reschedule_req_ids:
-                if self.resource_manager.requests[request_id].idx >= (batch - 1):  # No more token generated for preempted request
+                if self.resource_manager.requests[request_id].idx >= (
+                    batch - 1
+                ):  # No more token generated for preempted request
                     self.resource_manager.reschedule_preempt_task(request_id)
             for i in range(batch):
                 if self.resource_manager.stop_flags[i]:
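Per the inline comment, the guard reschedules a parked request only once the current decode batch can no longer produce a token for its slot: with batch live slots, a slot index at or beyond batch − 1 gets nothing this step. A small illustration of the trigger (reschedule_starved is an invented name; it reuses the toy manager sketched above, under the same assumptions):

def reschedule_starved(resource_manager, batch):
    # Snapshot first: reschedule_preempt_task mutates the set while we iterate.
    parked = list(resource_manager.to_be_rescheduled_request_id_set)
    for request_id in parked:
        # Slot indices at or beyond batch - 1 produced no token this step.
        if resource_manager.requests[request_id].idx >= (batch - 1):
            resource_manager.reschedule_preempt_task(request_id)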
