Skip to content

Commit 9f99718

Browse files
rainyflyroot
and authored
[Feature] Support ep pd with external module (#3194)
* Support external module * Support external module * Support external module * Support external module * refactor code to make it more clear * refactor code to make it more clear * refactor code to make it more clear * refactor code to make it more clear * fix according to review * fix according to review * fix according to review * fix according to review * fix according to review * fix according to review * fix bug * fix bug * fix bug * merge --------- Co-authored-by: root <[email protected]>
1 parent 0443587 commit 9f99718

File tree

15 files changed

+876
-218
lines changed

15 files changed

+876
-218
lines changed

fastdeploy/cache_manager/cache_messager.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,16 @@ def __init__(
142142

143143
self.gpu_id = gpu_id
144144
self.cache_info = dict()
145-
self.dp_rank_id = self.rank + local_data_parallel_id * self.nranks
145+
self.rank_id = self.rank + local_data_parallel_id * self.nranks # align with engine worker rank (paddle.distributed.launch)
146146

147147
layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
148148
layerwise_send_cache_thread.daemon = True
149149
layerwise_send_cache_thread.start()
150150

151+
connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
152+
connect_rdma_thread.daemon = True
153+
connect_rdma_thread.start()
154+
151155
logger.info(f"cache messager init finished, use {transfer_protocol}")
152156

153157
def _prefill_layerwise_send_cache_thread(self):
@@ -160,29 +164,29 @@ def _prefill_layerwise_send_cache_thread(self):
160164
prefilled_layer_idx_data = np.zeros(shape=[1], dtype=np.int32)
161165
try:
162166
step_shm_value = IPCSignal(
163-
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
167+
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
164168
array=prefilled_step_idx_data,
165169
dtype=np.int32,
166170
suffix=self.gpu_id,
167171
create=True,
168172
)
169173
layer_shm_value = IPCSignal(
170-
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
174+
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
171175
array=prefilled_layer_idx_data,
172176
dtype=np.int32,
173177
suffix=self.gpu_id,
174178
create=True,
175179
)
176180
except:
177181
step_shm_value = IPCSignal(
178-
name=f"splitwise_complete_prefilled_step_{self.dp_rank_id}",
182+
name=f"splitwise_complete_prefilled_step_{self.rank_id}",
179183
array=prefilled_step_idx_data,
180184
dtype=np.int32,
181185
suffix=self.gpu_id,
182186
create=False,
183187
)
184188
layer_shm_value = IPCSignal(
185-
name=f"splitwise_complete_prefilled_layer_{self.dp_rank_id}",
189+
name=f"splitwise_complete_prefilled_layer_{self.rank_id}",
186190
array=prefilled_layer_idx_data,
187191
dtype=np.int32,
188192
suffix=self.gpu_id,
@@ -310,3 +314,22 @@ def _prefill_layerwise_send_cache_thread(self):
310314

311315
except Exception as e:
312316
logger.error(f"prefill layerwise send cache thread has exception: {e}")
317+
318+
def _handle_connect_task(self):
319+
while True:
320+
try:
321+
task = self.engine_worker_queue.get_connect_rdma_task()
322+
if task is None:
323+
time.sleep(0.001)
324+
continue
325+
logger.info(f"_handle_connect_task recv task: {task}")
326+
task_id = task["task_id"]
327+
ip, rdma_port = task["ip"], task["rdma_port"]
328+
status = self.messager["rdma"].connect(ip, rdma_port)
329+
if not status:
330+
response = {"task_id": task_id, "success": False}
331+
else:
332+
response = {"task_id": task_id, "success": True}
333+
self.engine_worker_queue.put_connect_rdma_task_response(response)
334+
except Exception as e:
335+
logger.error(f"handle_connect_task has exception: {e}")

fastdeploy/engine/args_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,7 @@ def create_scheduler_config(self) -> SchedulerConfig:
820820
"max_num_partial_prefills",
821821
"max_long_partial_prefills",
822822
"long_prefill_token_threshold",
823+
"splitwise_role"
823824
]
824825

825826
all = asdict(self)

fastdeploy/engine/engine.py

Lines changed: 74 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,14 @@
4747
EngineCacheQueue,
4848
EngineWorkerQueue,
4949
IPCSignal,
50-
ZmqClient,
50+
ZmqIpcServer,
51+
ZmqTcpServer,
5152
)
5253
from fastdeploy.metrics.metrics import main_process_metrics
5354
from fastdeploy.metrics.trace_util import start_span, start_span_request
5455
from fastdeploy.model_executor.guided_decoding import schema_checker
5556
from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
57+
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
5658
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
5759
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
5860

@@ -179,11 +181,64 @@ def start(self, api_server_pid=None):
179181
self.data_processor = self.input_processor.create_processor()
180182

181183
if api_server_pid is not None:
182-
self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
183-
self.zmq_server.start_server()
184-
self.zmq_server.create_router()
184+
if envs.FD_ENABLE_INTERNAL_ADAPTER:
185+
self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
186+
self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
187+
self.external_adapter = InternalAdapter(
188+
cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
189+
)
190+
else:
191+
self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
192+
self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
185193
time.sleep(3)
186194

195+
self.cfg.init_cache_info()
196+
197+
role = self.cfg.splitwise_role
198+
host_ip = self.cfg.host_ip
199+
disaggregate = self.cfg.disaggregate_info
200+
request_queues_for_dp_ipc = (
201+
None # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp
202+
)
203+
result_queue_for_dp_ipc = None
204+
if self.cfg.scheduler_config.name == "splitwise":
205+
self.scheduler.start(role, host_ip, disaggregate)
206+
elif self.cfg.scheduler_config.name == "dp":
207+
request_queues_for_dp_ipc = []
208+
result_queue_for_dp_ipc = multiprocessing.Queue()
209+
for i in range(self.cfg.parallel_config.data_parallel_size):
210+
request_queues_for_dp_ipc.append(multiprocessing.Queue())
211+
self.scheduler.start(
212+
self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
213+
)
214+
215+
time.sleep(1)
216+
217+
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
218+
self.dp_processed = []
219+
for i in range(
220+
1,
221+
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
222+
):
223+
time.sleep(1)
224+
self.dp_processed.append(
225+
multiprocessing.Process(
226+
target=start_expert_service,
227+
args=(
228+
self.cfg,
229+
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
230+
self.ipc_signal_suffix,
231+
request_queues_for_dp_ipc,
232+
result_queue_for_dp_ipc,
233+
),
234+
)
235+
)
236+
llm_logger.info(
237+
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
238+
+ f" data parallel id {i}"
239+
)
240+
self.dp_processed[-1].start()
241+
187242
if self.do_profile == 0 and (
188243
self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
189244
):
@@ -238,44 +293,11 @@ def start(self, api_server_pid=None):
238293
# 单机逻辑
239294
self.engine_worker_queue.available_prefill_instances.put(1)
240295
self.split_mode_get_tasks()
241-
if self.cfg.scheduler_config.name == "splitwise":
296+
if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
242297
self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
243298
self.splitwise_receive_thread.daemon = True
244299
self.splitwise_receive_thread.start()
245300

246-
self.cfg.init_cache_info()
247-
248-
role = self.cfg.splitwise_role
249-
host_ip = self.cfg.host_ip
250-
disaggregate = self.cfg.disaggregate_info
251-
if self.cfg.scheduler_config.name == "splitwise":
252-
self.scheduler.start(role, host_ip, disaggregate)
253-
254-
time.sleep(1)
255-
256-
if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
257-
self.dp_processed = []
258-
for i in range(
259-
1,
260-
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
261-
):
262-
time.sleep(1)
263-
self.dp_processed.append(
264-
multiprocessing.Process(
265-
target=start_expert_service,
266-
args=(
267-
self.cfg,
268-
i + self.cfg.node_rank * self.cfg.worker_num_per_node,
269-
self.ipc_signal_suffix,
270-
),
271-
)
272-
)
273-
llm_logger.info(
274-
f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
275-
+ f" data parallel id {i}"
276-
)
277-
self.dp_processed[-1].start()
278-
279301
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
280302
return True
281303

@@ -291,7 +313,7 @@ def _zmq_send_generated_tokens(self):
291313
time.sleep(0.005)
292314
continue
293315
for request_id, contents in results.items():
294-
self.zmq_server.send_multipart(request_id, contents)
316+
self.send_response_server.send_response(request_id, contents)
295317

296318
except Exception as e:
297319
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -415,14 +437,18 @@ def _insert_zmq_task_to_scheduler(self):
415437
if self.api_server_pid is None:
416438
return
417439

440+
if envs.FD_ENABLE_INTERNAL_ADAPTER:
441+
if self.cfg.splitwise_role == "decode":
442+
return
443+
418444
added_requests: Dict[str, int] = dict()
419445
while self.running:
420446
try:
421447
block = True if len(added_requests) == 0 else False
422448
if not self.cfg.enable_mm:
423-
err, data = self.zmq_server.receive_json_once(block)
449+
err, data = self.recv_request_server.receive_json_once(block)
424450
else:
425-
err, data = self.zmq_server.receive_pyobj_once(block)
451+
err, data = self.recv_request_server.receive_pyobj_once(block)
426452
if err is not None:
427453
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
428454
break
@@ -470,7 +496,7 @@ def _insert_zmq_task_to_scheduler(self):
470496
)
471497
# Since the request is not in scheduler
472498
# Send result by zmq directly
473-
self.zmq_server.send_multipart(request_id, error_result)
499+
self.send_response_server.send_response(request_id, error_result)
474500
except Exception as e:
475501
llm_logger.error(
476502
f"Error happend while receving new request from zmq, details={e}, "
@@ -989,8 +1015,12 @@ def _exit_sub_services(self):
9891015
print(f"Error extracting sub services: {e}")
9901016

9911017
self.engine_worker_queue.cleanup()
992-
if hasattr(self, "zmq_server") and self.zmq_server is not None:
993-
self.zmq_server.close()
1018+
if hasattr(self, "send_response_server") and self.send_response_server is not None:
1019+
self.send_response_server.close()
1020+
if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
1021+
self.recv_request_server.close()
1022+
if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
1023+
self.recv_control_cmd_server.close()
9941024
if hasattr(self, "dp_processed"):
9951025
for p in self.dp_processed:
9961026
p.join()

fastdeploy/engine/expert_service.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
from fastdeploy.inter_communicator import EngineWorkerQueue
3030
from fastdeploy.metrics.metrics import main_process_metrics
3131
from fastdeploy.output.token_processor import TokenProcessor
32+
from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
3233
from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
33-
from fastdeploy.utils import EngineError, console_logger, llm_logger
34+
from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
3435

3536

3637
class ExpertService:
@@ -60,7 +61,8 @@ def __init__(self, cfg, local_data_parallel_id):
6061

6162
self.scheduler = cfg.scheduler_config.scheduler()
6263

63-
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
64+
if self.cfg.scheduler_config.name == "splitwise":
65+
self.scheduler.reset_nodeid(f"{self.scheduler.infer.nodeid}_{local_data_parallel_id!s}")
6466

6567
self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
6668

@@ -111,8 +113,12 @@ def __init__(self, cfg, local_data_parallel_id):
111113
)
112114

113115
self._finalizer = weakref.finalize(self, self._exit_sub_services)
116+
if envs.FD_ENABLE_INTERNAL_ADAPTER:
117+
self.external_adapter = InternalAdapter(cfg=self.cfg, engine=self, dp_rank=local_data_parallel_id)
114118

115-
def start(self, ipc_signal_suffix, local_data_parallel_id):
119+
def start(
120+
self, ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
121+
):
116122
"""
117123
Initializes the engine and starts its sub-services.
118124
If `api_server_pid` is defined, will launch a thread
@@ -127,7 +133,7 @@ def start(self, ipc_signal_suffix, local_data_parallel_id):
127133
cache_config=self.cfg.cache_config,
128134
tensor_parallel_size=self.cfg.tensor_parallel_size,
129135
device_ids=self.cfg.local_device_ids,
130-
pod_ip=self.cfg.pod_ips[0],
136+
pod_ip=self.cfg.master_ip,
131137
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
132138
pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
133139
)
@@ -147,7 +153,11 @@ def start(self, ipc_signal_suffix, local_data_parallel_id):
147153
role = self.cfg.splitwise_role
148154
host_ip = self.cfg.host_ip
149155
disaggregate = self.cfg.disaggregate_info
150-
self.scheduler.start(role, host_ip, disaggregate)
156+
if self.cfg.scheduler_config.name == "dp":
157+
assert (request_queues_for_dp_ipc is not None) and (result_queue_for_dp_ipc is not None)
158+
self.scheduler.start(local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc)
159+
elif self.cfg.scheduler_config.name == "splitwise":
160+
self.scheduler.start(role, host_ip, disaggregate)
151161
self.cfg.print()
152162

153163
console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
@@ -356,13 +366,17 @@ def _exit_sub_services(self):
356366
self.zmq_server.close()
357367

358368

359-
def start_expert_service(cfg, local_data_parallel_id, ipc_signal_suffix):
369+
def start_expert_service(
370+
cfg, local_data_parallel_id, ipc_signal_suffix, request_queues_for_dp_ipc=None, result_queue_for_dp_ipc=None
371+
):
360372
"""
361373
Start expert service
362374
"""
363375
expert_service = ExpertService(cfg, local_data_parallel_id)
364376
try:
365-
expert_service.start(ipc_signal_suffix, local_data_parallel_id)
377+
expert_service.start(
378+
ipc_signal_suffix, local_data_parallel_id, request_queues_for_dp_ipc, result_queue_for_dp_ipc
379+
)
366380
expert_service.split_connector.start_receiver()
367381
except Exception as e:
368382
llm_logger.exception(f"Expert service failed to start: {e}")

fastdeploy/engine/request.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
guided_json_object: Optional[bool] = None,
7272
enable_thinking: Optional[bool] = True,
7373
trace_carrier: dict = dict(),
74+
dp_rank: Optional[int] = None
7475
) -> None:
7576
self.request_id = request_id
7677
self.prompt = prompt
@@ -119,6 +120,7 @@ def __init__(
119120
self.task_type = RequestType.PREFILL
120121
self.idx = None
121122
self.need_prefill_tokens = self.prompt_token_ids_len
123+
self.dp_rank = dp_rank
122124

123125
@classmethod
124126
def from_dict(cls, d: dict):
@@ -151,6 +153,7 @@ def from_dict(cls, d: dict):
151153
guided_json_object=d.get("guided_json_object", None),
152154
enable_thinking=d.get("enable_thinking", True),
153155
trace_carrier=d.get("trace_carrier", {}),
156+
dp_rank=d.get("dp_rank", None)
154157
)
155158

156159
@property

fastdeploy/entrypoints/engine_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from fastdeploy.engine.config import ModelConfig
2323
from fastdeploy.input.preprocess import InputPreprocessor
24-
from fastdeploy.inter_communicator import IPCSignal, ZmqClient
24+
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient
2525
from fastdeploy.metrics.work_metrics import work_process_metrics
2626
from fastdeploy.multimodal.registry import MultimodalRegistry
2727
from fastdeploy.platforms import current_platform
@@ -90,7 +90,7 @@ def create_zmq_client(self, model, mode):
9090
"""
9191
Create a ZMQ client.
9292
"""
93-
self.zmq_client = ZmqClient(model, mode)
93+
self.zmq_client = ZmqIpcClient(model, mode)
9494
self.zmq_client.connect()
9595

9696
def format_and_add_data(self, prompts: dict):

0 commit comments

Comments
 (0)