Commit 7573802

[Feature] Support mtp ep in fd (#3340)

Adds support for MTP (multi-token prediction) speculative decoding under expert parallelism (EP) in FastDeploy (FD).

* [Optimize] Add metrics for analysing perf
* Fix bug in mtp

1 parent: 110f33a

File tree: 5 files changed (+24 −8 lines)

fastdeploy/engine/engine.py

Lines changed: 2 additions & 2 deletions

@@ -866,10 +866,10 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
 
+        for task in tasks:
+            task.inference_start_time = time.time()
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
-            for task in tasks:
-                task.inference_start_time = time.time()
         if not self.cfg.enable_mm:
             self.update_requests_chunk_size(tasks)
         else:
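
Hoisting the timestamp loop above the is_decode branch means every inserted task gets an inference_start_time, not only tasks on the non-decode path; the decode-side metrics added below in token_processor.py depend on it. A minimal sketch of the resulting behavior (_Task is a hypothetical stand-in for FastDeploy's request object):

    import time

    class _Task:
        # Hypothetical stand-in for FastDeploy's request object; the real
        # class carries many more fields.
        inference_start_time: float = 0.0

    def stamp_inference_start(tasks):
        # The hoisted loop from the diff: stamp every inserted task,
        # prefill or decode, so downstream timing is always defined.
        for task in tasks:
            task.inference_start_time = time.time()

    tasks = [_Task(), _Task()]
    stamp_inference_start(tasks)
    # token_processor.py computes first_token_time and the new
    # model_execute_time as time.time() - task.inference_start_time.
    assert time.time() - tasks[0].inference_start_time >= 0.0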

fastdeploy/engine/expert_service.py

Lines changed: 4 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import copy
 import os
 import signal
 import threading

@@ -293,6 +294,9 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
             cur_task_idx = self.resource_manager.req_dict[task.request_id]
             del self.resource_manager.req_dict[task.request_id]
             cur_task = self.resource_manager.tasks_list[cur_task_idx]
+            cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
+            if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
+                cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
             if task.error_code != 200:
                 self.resource_manager.stop_flags[cur_task_idx] = True
                 self.resource_manager.tasks_list[cur_task_idx] = None
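
On a decode-role instance running MTP, the request handed over from prefill now appears to be seeded with both the first generated token and the prefill side's draft tokens, so decode can verify the drafts instead of regenerating them. A rough sketch of that handoff under those assumptions, with simplified stand-ins for the task and output objects:

    import copy
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class _Outputs:
        token_ids: List[int]
        draft_token_ids: List[int] = field(default_factory=list)

    @dataclass
    class _Task:
        prompt_token_ids: List[int]
        draft_token_ids: List[int] = field(default_factory=list)

    def adopt_prefill_result(cur_task: _Task, outputs: _Outputs,
                             method: str, splitwise_role: str) -> None:
        # Seed the decode-side copy of the request with the first token
        # generated by the prefill instance.
        cur_task.prompt_token_ids[0] = outputs.token_ids[0]
        # For MTP, also adopt the prefill instance's draft tokens;
        # deepcopy avoids aliasing the transport payload.
        if method in ["mtp"] and splitwise_role == "decode":
            cur_task.draft_token_ids = copy.deepcopy(outputs.draft_token_ids)

    task = _Task(prompt_token_ids=[0, 11, 22])
    adopt_prefill_result(task,
                         _Outputs(token_ids=[33], draft_token_ids=[44, 55]),
                         method="mtp", splitwise_role="decode")
    assert task.prompt_token_ids[0] == 33 and task.draft_token_ids == [44, 55]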

fastdeploy/output/token_processor.py

Lines changed: 11 additions & 2 deletions

@@ -195,7 +195,14 @@ def process_sampling_results(self):
             try:
                 is_blocking = True
                 if self.speculative_decoding:
-                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                    if (
+                        self.cfg.parallel_config.enable_expert_parallel
+                        and self.cfg.parallel_config.data_parallel_size > 1
+                    ):
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                    else:
+
+                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
                 if self.output_tokens[0] == -2:
                     continue
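
The new branch reads speculative outputs differently when expert parallelism is combined with more than one data-parallel rank; the diff only shows the final flag of speculate_get_output flipping to True, so what that flag selects internally is an assumption here. A condensed, runnable view of the dispatch condition (ParallelConfig is a minimal stand-in):

    from dataclasses import dataclass

    @dataclass
    class ParallelConfig:
        # Minimal stand-in for cfg.parallel_config; only the two fields
        # the new branch inspects.
        enable_expert_parallel: bool
        data_parallel_size: int

    def use_ep_output_flag(cfg: ParallelConfig) -> bool:
        # Mirrors the diff: only EP deployments with more than one
        # data-parallel rank flip the last speculate_get_output argument.
        return cfg.enable_expert_parallel and cfg.data_parallel_size > 1

    assert use_ep_output_flag(ParallelConfig(True, 2)) is True
    assert use_ep_output_flag(ParallelConfig(True, 1)) is False
    assert use_ep_output_flag(ParallelConfig(False, 2)) is False
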
@@ -478,6 +485,7 @@ def _process_batch_output(self):
                         arrival_time=task.arrival_time,
                         inference_start_time=task.inference_start_time,
                         first_token_time=time.time() - task.inference_start_time,
+                        model_execute_time=time.time() - task.inference_start_time,
                         time_in_queue=task.schedule_start_time - task.preprocess_end_time,
                         preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
                         request_start_time=task.arrival_time,

@@ -489,6 +497,7 @@ def _process_batch_output(self):
                     metrics = RequestMetrics(
                         arrival_time=time.time(),
                         request_start_time=task.arrival_time,
+                        model_execute_time=time.time() - task.inference_start_time,
                     )
                 self.number_of_output_tokens += len(token_ids)
                 self._record_metrics(task, current_time, token_ids)
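
Both metric-construction paths now record model_execute_time, measured from the inference_start_time stamped in engine.py above. A small sketch of how such a metric is derived (_RequestMetrics is a trimmed, hypothetical stand-in for FastDeploy's RequestMetrics):

    import time
    from dataclasses import dataclass

    @dataclass
    class _RequestMetrics:
        # Trimmed stand-in; only the fields relevant to this change.
        arrival_time: float
        request_start_time: float
        model_execute_time: float = 0.0

    def build_metrics(task_arrival: float, inference_start: float) -> _RequestMetrics:
        now = time.time()
        # model_execute_time is wall time since the task was handed to
        # the engine, filled on both the first-token and the
        # subsequent-token paths.
        return _RequestMetrics(
            arrival_time=now,
            request_start_time=task_arrival,
            model_execute_time=now - inference_start,
        )

    m = build_metrics(task_arrival=time.time() - 1.5,
                      inference_start=time.time() - 0.2)
    assert m.model_execute_time >= 0.0
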
@@ -506,7 +515,7 @@ def _process_batch_output(self):
                 if self.tokens_counter[task_id] == 0:
                     if task.messages is not None:
                         result.prompt = task.messages
-                    result.num_cached_tokens = task.num_cached_tokens
+                        result.num_cached_tokens = task.num_cached_tokens
 
                 is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"

(The last hunk appears to be an indentation-only change: result.num_cached_tokens moves under the task.messages guard, so it is only set when chat messages are attached to the task.)

fastdeploy/spec_decode/mtp.py

Lines changed: 6 additions & 4 deletions

@@ -76,6 +76,7 @@ def _update_cfg(self, main_model):
         self.model_config.num_hidden_layers = 1
         self.model_config.model = self.speculative_config.model
         self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
+        self.model_config.is_quantized = False
         if self.speculative_config.quantization != "":
             self.model_config.quantization = self.speculative_config.quantization
         self.model_config.start_layer_index = self.num_main_model_layers

@@ -142,15 +143,16 @@ def initialize_kv_cache(self):
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
             max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
+        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
         if not self.parallel_config.do_profile and self.parallel_config.splitwise_role != "mixed":
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,
                 self.num_main_model_layers + self.model_config.num_hidden_layers,
             ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
-                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
+                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                 key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)

@@ -176,11 +178,11 @@ def initialize_kv_cache(self):
             if self.cache_config.enable_prefix_caching:
                 set_data_ipc(
                     self.cache_kvs[f"key_caches_{i}"],
-                    f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    f"key_caches_{i}_rank{local_rank}.device{self.device_id}",
                 )
                 set_data_ipc(
                     self.cache_kvs[f"value_caches_{i}"],
-                    f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}",
+                    f"value_caches_{i}_rank{local_rank}.device{self.device_id}",
                 )
         self.model_inputs["caches"] = list(self.cache_kvs.values())
         for value in self.cache_kvs.values():
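
The IPC cache names are now keyed by the rank within the tensor-parallel group rather than the global local_rank, so that with expert/data parallelism each replica resolves the same per-device names the cache service exported; that reading of the motivation is an inference from the diff. The mapping itself is simple arithmetic:

    def tp_group_rank(local_rank: int, tensor_parallel_size: int) -> int:
        # With data parallelism, local ranks run past the TP group size;
        # the modulo folds them back to a position inside the group.
        return local_rank % tensor_parallel_size

    # Example with TP=4, DP=2: ranks 4..7 form the second replica and map
    # back to 0..3, so rank 5 opens "key_caches_<i>_rank1.device<id>"
    # exactly like rank 1 does.
    assert [tp_group_rank(r, 4) for r in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]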

fastdeploy/splitwise/splitwise_connector.py

Lines changed: 1 addition & 0 deletions

@@ -503,6 +503,7 @@ def _handle_decode(self, payload):
                 index=task["outputs"]["index"],
                 send_idx=0,
                 token_ids=task["outputs"]["token_ids"],
+                draft_token_ids=task["outputs"]["draft_token_ids"],
             ),
             finished=True,
         )
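
This is the wire-level counterpart of the expert_service.py change above: the connector forwards the prefill instance's draft tokens inside the decode handoff. A hypothetical payload, showing only the subscripted fields from the diff (all values invented for illustration):

    # Hypothetical handoff payload; fields follow the subscripts in the
    # diff, everything else is omitted.
    task = {
        "outputs": {
            "index": 0,
            "token_ids": [33],            # first token generated by prefill
            "draft_token_ids": [44, 55],  # MTP drafts for decode to verify
        }
    }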
