     EngineCacheQueue,
     EngineWorkerQueue,
     IPCSignal,
-    ZmqClient,
+    ZmqIpcServer,
+    ZmqTcpServer,
 )
 from fastdeploy.metrics.metrics import main_process_metrics
 from fastdeploy.metrics.trace_util import start_span, start_span_request
 from fastdeploy.model_executor.guided_decoding import schema_checker
 from fastdeploy.output.token_processor import TokenProcessor, WarmUpTokenProcessor
+from fastdeploy.splitwise.internal_adapter_utils import InternalAdapter
 from fastdeploy.splitwise.splitwise_connector import SplitwiseConnector
 from fastdeploy.utils import EngineError, console_logger, envs, llm_logger
 
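Note on the import swap: the single ZmqClient is replaced by two server-side classes, ZmqIpcServer (the existing IPC path) and ZmqTcpServer (used when FD_ENABLE_INTERNAL_ADAPTER is enabled), plus an InternalAdapter from the splitwise package. The sketch below shows the plain pyzmq PULL/ROUTER pattern these classes appear to wrap; the port numbers, endpoints, and JSON payload are illustrative assumptions, not the fastdeploy API.

```python
# Illustrative sketch only: a PULL socket receives requests, a ROUTER socket
# sends responses back to a specific client, roughly the shape of the new
# ZmqTcpServer(mode=zmq.PULL) / ZmqTcpServer(mode=zmq.ROUTER) pair.
import json
import zmq

ctx = zmq.Context()

recv_request = ctx.socket(zmq.PULL)      # clients PUSH requests here
recv_request.bind("tcp://*:8201")        # port is a made-up example

send_response = ctx.socket(zmq.ROUTER)   # clients DEALER-connect, identity = request_id
send_response.bind("tcp://*:8202")

request = json.loads(recv_request.recv())   # blocks until a client PUSHes a request
reply = json.dumps({"request_id": request["request_id"], "tokens": []})
# ROUTER frames are [peer identity, payload]; the identity routes the reply,
# so a peer must be connected with that identity or the message is dropped.
send_response.send_multipart([request["request_id"].encode(), reply.encode()])
```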
@@ -179,11 +181,64 @@ def start(self, api_server_pid=None):
         self.data_processor = self.input_processor.create_processor()
 
         if api_server_pid is not None:
-            self.zmq_server = ZmqClient(name=api_server_pid, mode=zmq.PULL)
-            self.zmq_server.start_server()
-            self.zmq_server.create_router()
+            if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                self.recv_request_server = ZmqTcpServer(port=envs.FD_ZMQ_RECV_REQUEST_SERVER_PORT, mode=zmq.PULL)
+                self.send_response_server = ZmqTcpServer(port=envs.FD_ZMQ_SEND_RESPONSE_SERVER_PORT, mode=zmq.ROUTER)
+                self.external_adapter = InternalAdapter(
+                    cfg=self.cfg, engine=self, dp_rank=self.cfg.node_rank * self.cfg.worker_num_per_node
+                )
+            else:
+                self.recv_request_server = ZmqIpcServer(name=api_server_pid, mode=zmq.PULL)
+                self.send_response_server = ZmqIpcServer(name=api_server_pid, mode=zmq.ROUTER)
             time.sleep(3)
 
+        self.cfg.init_cache_info()
+
+        role = self.cfg.splitwise_role
+        host_ip = self.cfg.host_ip
+        disaggregate = self.cfg.disaggregate_info
+        request_queues_for_dp_ipc = (
+            None  # Different dp has its own process, use multiprocessing.Queue to deliver requests for each dp
+        )
+        result_queue_for_dp_ipc = None
+        if self.cfg.scheduler_config.name == "splitwise":
+            self.scheduler.start(role, host_ip, disaggregate)
+        elif self.cfg.scheduler_config.name == "dp":
+            request_queues_for_dp_ipc = []
+            result_queue_for_dp_ipc = multiprocessing.Queue()
+            for i in range(self.cfg.parallel_config.data_parallel_size):
+                request_queues_for_dp_ipc.append(multiprocessing.Queue())
+            self.scheduler.start(
+                self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
+            )
+
+        time.sleep(1)
+
+        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
+            self.dp_processed = []
+            for i in range(
+                1,
+                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
+            ):
+                time.sleep(1)
+                self.dp_processed.append(
+                    multiprocessing.Process(
+                        target=start_expert_service,
+                        args=(
+                            self.cfg,
+                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
+                            self.ipc_signal_suffix,
+                            request_queues_for_dp_ipc,
+                            result_queue_for_dp_ipc,
+                        ),
+                    )
+                )
+                llm_logger.info(
+                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
+                    + f" data parallel id {i}"
+                )
+                self.dp_processed[-1].start()
+
         if self.do_profile == 0 and (
             self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
         ):
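The new "dp" scheduler branch gives every data-parallel rank its own multiprocessing.Queue for inbound requests and a single shared queue for results, and those queues are now forwarded into the start_expert_service processes. A minimal, self-contained sketch of that fan-out/fan-in shape, with a hypothetical dp_worker standing in for start_expert_service:

```python
# Sketch of the per-DP-rank queue wiring used by the "dp" scheduler path:
# one request queue per data-parallel rank, one shared result queue.
import multiprocessing

def dp_worker(rank, request_queue, result_queue):
    # Each DP rank consumes only its own requests and reports to the shared queue.
    task = request_queue.get()
    result_queue.put((rank, f"done:{task}"))

if __name__ == "__main__":
    data_parallel_size = 4
    request_queues = [multiprocessing.Queue() for _ in range(data_parallel_size)]
    result_queue = multiprocessing.Queue()

    procs = [
        multiprocessing.Process(target=dp_worker, args=(i, request_queues[i], result_queue))
        for i in range(data_parallel_size)
    ]
    for p in procs:
        p.start()

    for i, q in enumerate(request_queues):
        q.put(f"req-{i}")            # the scheduler would route real requests here
    for _ in range(data_parallel_size):
        print(result_queue.get())    # fan-in of results from all ranks
    for p in procs:
        p.join()
```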
@@ -238,44 +293,11 @@ def start(self, api_server_pid=None):
             # 单机逻辑 (single-machine logic)
             self.engine_worker_queue.available_prefill_instances.put(1)
             self.split_mode_get_tasks()
-            if self.cfg.scheduler_config.name == "splitwise":
+            if self.cfg.scheduler_config.name == "splitwise" or self.cfg.scheduler_config.name == "dp":
                 self.splitwise_receive_thread = threading.Thread(target=self.split_connector.start_receiver, args=())
                 self.splitwise_receive_thread.daemon = True
                 self.splitwise_receive_thread.start()
 
-        self.cfg.init_cache_info()
-
-        role = self.cfg.splitwise_role
-        host_ip = self.cfg.host_ip
-        disaggregate = self.cfg.disaggregate_info
-        if self.cfg.scheduler_config.name == "splitwise":
-            self.scheduler.start(role, host_ip, disaggregate)
-
-        time.sleep(1)
-
-        if self.cfg.parallel_config.enable_expert_parallel and self.cfg.parallel_config.data_parallel_size > 1:
-            self.dp_processed = []
-            for i in range(
-                1,
-                self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
-            ):
-                time.sleep(1)
-                self.dp_processed.append(
-                    multiprocessing.Process(
-                        target=start_expert_service,
-                        args=(
-                            self.cfg,
-                            i + self.cfg.node_rank * self.cfg.worker_num_per_node,
-                            self.ipc_signal_suffix,
-                        ),
-                    )
-                )
-                llm_logger.info(
-                    f"Engine is initialized successfully with {self.cfg.tensor_parallel_size}"
-                    + f" data parallel id {i}"
-                )
-                self.dp_processed[-1].start()
-
         console_logger.info(f"Worker processes are launched with {time.time() - start_time} seconds.")
         return True
 
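The receiver-thread condition now matches two scheduler names. As a possible (out-of-scope) tidy-up, a membership test expresses the same check more compactly:

```python
# Suggestion only, not part of this diff: the duplicated comparison
# reads tighter as a membership test.
scheduler_name = "dp"  # stand-in for self.cfg.scheduler_config.name
if scheduler_name in ("splitwise", "dp"):
    print("start the splitwise receiver thread")
```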
@@ -291,7 +313,7 @@ def _zmq_send_generated_tokens(self):
                     time.sleep(0.005)
                     continue
                 for request_id, contents in results.items():
-                    self.zmq_server.send_multipart(request_id, contents)
+                    self.send_response_server.send_response(request_id, contents)
 
             except Exception as e:
                 llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -415,14 +437,18 @@ def _insert_zmq_task_to_scheduler(self):
         if self.api_server_pid is None:
             return
 
+        if envs.FD_ENABLE_INTERNAL_ADAPTER:
+            if self.cfg.splitwise_role == "decode":
+                return
+
         added_requests: Dict[str, int] = dict()
         while self.running:
             try:
                 block = True if len(added_requests) == 0 else False
                 if not self.cfg.enable_mm:
-                    err, data = self.zmq_server.receive_json_once(block)
+                    err, data = self.recv_request_server.receive_json_once(block)
                 else:
-                    err, data = self.zmq_server.receive_pyobj_once(block)
+                    err, data = self.recv_request_server.receive_pyobj_once(block)
                 if err is not None:
                     llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
                     break
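The unchanged `block = True if len(added_requests) == 0 else False` line means the loop only blocks on ZMQ when nothing is already waiting for the scheduler. A standalone sketch of that block-only-when-idle pattern in plain pyzmq; the endpoint and message handling are made up for illustration:

```python
# Sketch of the "block only when idle" receive pattern used by
# _insert_zmq_task_to_scheduler; plain pyzmq, not the fastdeploy server classes.
import zmq

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)
pull.bind("ipc:///tmp/example_recv_request.ipc")   # hypothetical endpoint

pending = []  # stands in for added_requests
while True:
    flags = 0 if not pending else zmq.NOBLOCK      # block only when nothing is pending
    try:
        pending.append(pull.recv(flags=flags))
    except zmq.Again:
        pass                                       # nothing new; flush what we already have
    # ... hand the pending items to the scheduler here ...
    pending.clear()
```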
@@ -470,7 +496,7 @@ def _insert_zmq_task_to_scheduler(self):
                     )
                     # Since the request is not in scheduler
                     # Send result by zmq directly
-                    self.zmq_server.send_multipart(request_id, error_result)
+                    self.send_response_server.send_response(request_id, error_result)
             except Exception as e:
                 llm_logger.error(
                     f"Error happend while receving new request from zmq, details={e}, "
@@ -989,8 +1015,12 @@ def _exit_sub_services(self):
             print(f"Error extracting sub services: {e}")
 
         self.engine_worker_queue.cleanup()
-        if hasattr(self, "zmq_server") and self.zmq_server is not None:
-            self.zmq_server.close()
+        if hasattr(self, "send_response_server") and self.send_response_server is not None:
+            self.send_response_server.close()
+        if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
+            self.recv_request_server.close()
+        if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
+            self.recv_control_cmd_server.close()
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
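With three optional servers to tear down, _exit_sub_services now repeats the same hasattr/None/close check three times. A small sketch of a possible alternative (suggestion only, with stand-in attributes) that folds the checks into one loop:

```python
# Suggestion sketch: close whichever optional ZMQ servers exist, in one loop.
class _Engine:
    def __init__(self):
        self.send_response_server = None   # stand-ins for the real server objects
        self.recv_request_server = None
        self.recv_control_cmd_server = None

    def _close_zmq_servers(self):
        for name in ("send_response_server", "recv_request_server", "recv_control_cmd_server"):
            server = getattr(self, name, None)
            if server is not None:
                server.close()

_Engine()._close_zmq_servers()   # no-op here; real servers would be closed
```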