
Commit b20ffe3

[Feature] optimize expert parallel (#3196)
* optimize
* Update expert_service.py
* Update worker_process.py
* optimize
1 parent dcf9c2d · commit b20ffe3

7 files changed: +174 −134 lines changed
fastdeploy/cache_manager/cache_messager.py

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ def __init__(
         self.gpu_cache_kvs = gpu_cache_kvs
         self.rank = rank
         self.nranks = nranks
-        address = (pod_ip, engine_worker_queue_port)
+        address = (pod_ip, engine_worker_queue_port + local_data_parallel_id)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,
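This is the client side of the change: each local data-parallel rank now dials its own port instead of sharing one. A minimal sketch of the addressing scheme (`queue_address` is an illustrative helper, not FastDeploy API):

```python
# Illustrative helper (not FastDeploy API): each local data-parallel rank
# derives its own queue endpoint by offsetting the shared base port.
def queue_address(pod_ip: str, base_port: int, local_data_parallel_id: int) -> tuple:
    return (pod_ip, base_port + local_data_parallel_id)

# DP rank 0 keeps the base port, rank 1 gets base + 1, and so on.
assert queue_address("10.0.0.1", 6677, 1) == ("10.0.0.1", 6678)
```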

fastdeploy/engine/engine.py

Lines changed: 63 additions & 51 deletions
@@ -28,6 +28,7 @@
 import traceback
 import uuid
 import weakref
+from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple

@@ -125,9 +126,17 @@ def __init__(self, cfg):
             cfg.max_num_seqs, cfg, cfg.tensor_parallel_size, cfg.splitwise_role
         )

-        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.cfg.engine_worker_queue_port)
-
-        self.split_connector = SplitwiseConnector(cfg, self.scheduler, self.engine_worker_queue, self.resource_manager)
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = str(
+            self.cfg.engine_worker_queue_port + self.cfg.parallel_config.local_data_parallel_id
+        )
+        self.splitwise_queue = deque()
+        self.split_connector = SplitwiseConnector(
+            cfg,
+            self.scheduler,
+            self.engine_worker_queue,
+            self.resource_manager,
+            self.splitwise_queue,
+        )

         self.token_processor = TokenProcessor(
             cfg=self.cfg,
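Alongside the per-rank queue ID, the engine now owns a plain `collections.deque` that it hands to `SplitwiseConnector`, so disaggregated tasks flow in-process rather than through the IPC worker queue. A minimal sketch of that handoff with illustrative names (the producer side lives in the connector, which this excerpt does not show); CPython's `deque` append/pop operations are atomic, so a simple two-thread handoff needs no explicit lock:

```python
import threading
from collections import deque

# Shared in-process channel between SplitwiseConnector (producer) and the
# engine's receiver loop (consumer). Names here are illustrative.
splitwise_queue = deque()

def producer() -> None:
    # connector side: hand a ("role", tasks) pair to the engine
    splitwise_queue.appendleft(("prefill", ["req-0", "req-1"]))

def consumer() -> None:
    # engine side: non-blocking poll, mirroring the receiver_loop change below
    if len(splitwise_queue) > 0:
        role, tasks = splitwise_queue.pop()
        print(role, tasks)  # -> prefill ['req-0', 'req-1']

t = threading.Thread(target=producer)
t.start()
t.join()
consumer()
```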
@@ -343,12 +352,6 @@ def _insert_task_to_worker(self):
             if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
                 time.sleep(0.005)
                 continue
-            if self.engine_worker_queue.num_cache_infos() > 0:
-                time.sleep(0.001)
-                continue
-            if len(self.split_connector.current_request_ids) > 0:
-                time.sleep(0.001)
-                continue

             num_prefill_batch = min(
                 int(self.resource_manager.available_batch()),
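With the two extra polls removed, the insertion loop no longer gates on pending cache infos or in-flight splitwise request IDs; only the first check remains. A condensed sketch of the remaining back-off (stub predicate, illustrative names):

```python
import time

# Stub predicate mirroring the one remaining gate:
# cfg.splitwise_role == "mixed" or split_connector.has_splitwise_tasks()
def should_back_off(splitwise_role: str, has_splitwise_tasks: bool) -> bool:
    return splitwise_role == "mixed" or has_splitwise_tasks

while should_back_off("decode", True):
    time.sleep(0.005)  # back off while splitwise work is pending
    break  # demo only; the real loop re-checks via `continue`
```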
@@ -596,43 +599,42 @@ def receiver_loop():
             for idx in sorted(processed_indices, reverse=True):
                 self.waiting_requests.pop(idx)

-            if not self.engine_worker_queue.disaggregate_queue_empty():
-                items = self.engine_worker_queue.get_disaggregated_tasks()
-                for item in items:
-                    role = item[0]
-                    tasks = item[1]
+            if len(self.splitwise_queue) > 0:
+                items = self.splitwise_queue.pop()
+                role = items[0]
+                tasks = items[1]

-                    if role == "prefill":
-                        for task in tasks:
-                            task.max_tokens = task.min_tokens = 2
-                        self.insert_tasks(tasks)
+                if role == "prefill":
+                    for task in tasks:
+                        task.max_tokens = task.min_tokens = 2
+                    self.insert_tasks(tasks)

-                    elif role == "decode":
-                        if hasattr(tasks[0], "finished"):
-                            if not isinstance(tasks, list):
-                                tasks = [tasks]
-                            for task in tasks:
-                                task.finished = False
-                            self.insert_tasks(tasks, allocated=True)
+                elif role == "decode":
+                    if hasattr(tasks[0], "finished"):
+                        if not isinstance(tasks, list):
+                            tasks = [tasks]
+                        for task in tasks:
+                            task.finished = False
+                        self.insert_tasks(tasks, allocated=True)

-                            if self.cfg.innode_prefill_ports is not None:
-                                self.scheduler.put_results(tasks)
+                        if self.cfg.innode_prefill_ports is not None:
+                            self.scheduler.put_results(tasks)

+                    else:
+                        if len(self.waiting_requests):
+                            llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
+                            self.waiting_requests.extend(tasks)
                         else:
-                            if len(self.waiting_requests):
-                                llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
-                                self.waiting_requests.extend(tasks)
-                            else:
-                                new_waiting = []
-                                for task in tasks:
-                                    if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
-                                        self.insert_tasks([task])
-                                    else:
-                                        new_waiting.append(task)
-
-                                if new_waiting:
-                                    self.waiting_requests.extend(new_waiting)
-                                    llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")
+                            new_waiting = []
+                            for task in tasks:
+                                if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
+                                    self.insert_tasks([task])
+                                else:
+                                    new_waiting.append(task)
+
+                            if new_waiting:
+                                self.waiting_requests.extend(new_waiting)
+                                llm_logger.info(f"Added {len(new_waiting)} tasks to waiting queue")

             else:
                 time.sleep(0.001)
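A condensed, runnable sketch of the dispatch the rewritten loop performs: pop one `(role, tasks)` entry per iteration and branch on the role. `Task` is a stub for FastDeploy's request object, and the resource-manager fallback path is omitted:

```python
from collections import deque
from dataclasses import dataclass

@dataclass
class Task:  # stub for FastDeploy's request object
    request_id: str
    max_tokens: int = 0
    min_tokens: int = 0
    finished: bool = True

splitwise_queue = deque([("decode", [Task("req-1")]), ("prefill", [Task("req-0")])])

while len(splitwise_queue) > 0:
    role, tasks = splitwise_queue.pop()  # pop() takes from the right, as in the diff
    if role == "prefill":
        for t in tasks:
            t.max_tokens = t.min_tokens = 2  # prefill emits only the first token(s)
        print("insert prefill:", [t.request_id for t in tasks])
    elif role == "decode" and hasattr(tasks[0], "finished"):
        for t in tasks:
            t.finished = False  # re-arm tasks handed over from the prefill side
        print("insert decode (allocated):", [t.request_id for t in tasks])
```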
@@ -842,7 +844,6 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len

-        self.split_connector.send_cache_infos(tasks, current_id)
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
             for task in tasks:
@@ -854,6 +855,8 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
         self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
         if is_prefill and self.cfg.scheduler_config.name != "splitwise":
             self.engine_worker_queue.available_prefill_instances.put(1)
+
+        self.split_connector.send_cache_infos(tasks, current_id)
         return True

     def task_is_finished(self, index):
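Taken together, the two insert_tasks hunks move `send_cache_infos` from before the worker-queue hand-off to after it, so the splitwise peer is notified only once the tasks are already enqueued locally. A condensed sketch of the new ordering, with stub classes standing in for `EngineWorkerQueue` and `SplitwiseConnector`:

```python
class StubWorkerQueue:
    def put_tasks(self, payload):
        print("1) tasks handed to workers:", payload)

class StubConnector:
    def send_cache_infos(self, tasks, current_id):
        print("2) peer notified of cache infos:", tasks, "current_id =", current_id)

def insert_tasks(tasks, current_id=-1):
    StubWorkerQueue().put_tasks((tasks, len(tasks)))     # enqueue first...
    StubConnector().send_cache_infos(tasks, current_id)  # ...then notify the peer
    return True

insert_tasks(["req-0"])
```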
@@ -1017,13 +1020,16 @@ def _exit_sub_services(self):
         except Exception as e:
             print(f"Error extracting sub services: {e}")

-        self.engine_worker_queue.cleanup()
+
+        for worker_queue in self.engine_worker_queue_server:
+            worker_queue.cleanup()
         if hasattr(self, "send_response_server") and self.send_response_server is not None:
             self.send_response_server.close()
         if hasattr(self, "recv_request_server") and self.recv_request_server is not None:
             self.recv_request_server.close()
         if hasattr(self, "recv_control_cmd_server") and self.recv_control_cmd_server is not None:
             self.recv_control_cmd_server.close()
+
         if hasattr(self, "dp_processed"):
             for p in self.dp_processed:
                 p.join()
@@ -1325,15 +1331,20 @@ def start_queue_service(self):
         """
        start queue service for engine worker communication
         """
-        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
+
+        self.engine_worker_queue_server = list()
         if self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0":
-            llm_logger.info(f"Starting engine worker queue server service at {address}")
-            self.engine_worker_queue_server = EngineWorkerQueue(
-                address=address,
-                is_server=True,
-                num_client=self.cfg.tensor_parallel_size,
-                local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
-            )
+            for i in range(self.cfg.parallel_config.data_parallel_size // self.cfg.nnode):
+                address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port + i)
+                llm_logger.info(f"Starting engine worker queue service at {address}")
+                self.engine_worker_queue_server.append(
+                    EngineWorkerQueue(
+                        address=address,
+                        is_server=True,
+                        num_client=self.cfg.tensor_parallel_size,
+                        local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
+                    )
+                )

         if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
             self.cache_task_queue = EngineCacheQueue(
@@ -1348,6 +1359,7 @@ def start_queue_service(self):
                 local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             )

+        address = (self.cfg.master_ip, self.cfg.engine_worker_queue_port)
         self.engine_worker_queue = EngineWorkerQueue(
             address=address,
             is_server=False,