
Commit e2c764f

update hybrid-mtp-with-ngram (#3924)
1 parent 2d975e1 commit e2c764f

3 files changed: +39 -3 lines changed

docs/features/speculative_decoding.md

Lines changed: 13 additions & 0 deletions
@@ -18,6 +18,13 @@ This project implements an efficient **Speculative Decoding** inference framework
 - ⏳ Coming Soon: Support Chunk-prefill
 - ⏳ Coming Soon: Multi-layer MTP Layer

+- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**
+  - Overview: A hybrid of the MTP and Ngram methods. MTP first generates N draft tokens, then Ngram matching supplements additional draft tokens.
+  - Use Cases: Suitable when higher draft-token coverage is required, combining MTP's generation capability with the efficiency of Ngram matching.
+
 ---

 ### Coming Soon
@@ -132,7 +139,13 @@ python -m fastdeploy.entrypoints.openai.api_server \
     --scheduler-password "scheduler_mtp" \
     --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
 ```
+
+## Decoding with Hybrid MTP and Ngram Methods
+
+When starting the service, you only need to modify the --speculative-config option.
+For example, to have MTP generate two draft tokens and then append three additional draft tokens from Ngram matching:
+```
+--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram", "num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'
+```

 ## 🧠 Using Ngram-Based Decoding
 This method uses an n-gram sliding window to match the prompt and generated tokens to predict draft tokens. It is particularly effective in scenarios with high input-output overlap (e.g., code completion, document search).
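As an aside for readers of this commit, below is a minimal sketch of how the fields in the hybrid --speculative-config shown above relate to each other, assuming the split described in the docs (MTP drafts plus an Ngram top-up). The dictionary form, the budget arithmetic, and the placeholder model path are illustrative assumptions, not FastDeploy's actual launch API.

```python
# Hypothetical illustration of the hybrid config's draft-token budget; the
# dict form and the placeholder model path are assumptions, not the
# FastDeploy launch API.
speculative_config = {
    "method": "mtp",               # the base drafter is the MTP head
    "num_model_steps": 2,          # draft tokens proposed by MTP per step
    "mtp_strategy": "with_ngram",  # top up the remaining budget via n-gram matching
    "num_speculative_tokens": 5,   # total draft tokens per verification step
    "model": "path/to/mtp",        # placeholder path (assumption)
}

# For the example in the docs above, n-gram matching may contribute the rest:
ngram_top_up = (
    speculative_config["num_speculative_tokens"] - speculative_config["num_model_steps"]
)
print(ngram_top_up)  # -> 3 extra draft tokens from n-gram matching
```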

docs/zh/features/speculative_decoding.md

Lines changed: 8 additions & 0 deletions
@@ -14,6 +14,9 @@
 - ⏳ Coming Soon: Chunk Prefill compatibility
 - ⏳ Coming Soon: Multi-layer MTP layer

+- **Decoding with Hybrid MTP and Ngram Methods (Hybrid-MTP-with-Ngram)**
+  - Overview: A hybrid of the MTP and Ngram methods: MTP first produces N draft tokens, then Ngram matching supplements additional draft tokens.
+  - Use Cases: Suitable when more draft tokens are needed, combining MTP's generation capability with the efficiency of Ngram matching.
 ---

 ### ⏳ Planned
@@ -110,7 +113,12 @@ python -m fastdeploy.entrypoints.openai.api_server \
     --scheduler-password "scheduler_mtp" \
     --speculative-config '{"method": "mtp", "num_speculative_tokens": 1, "model": "${path_to_mtp_model}"}' &
 ```
+## Decoding with Hybrid MTP and Ngram Methods
+When starting the service, only the --speculative-config option needs to change. For example, to have MTP produce two draft tokens and then append three additional draft tokens matched by Ngram:
+```
+--speculative-config '{"method": "mtp", "num_model_steps": 2, "mtp_strategy": "with_ngram", "num_speculative_tokens": 5, "model": "'$model_path'/mtp"}'
+```
 ## 🧠 Using Ngram-Based Decoding
 This algorithm uses an n-gram window to match against the prompt and already generated tokens to produce draft tokens; it suits scenarios with large input/output overlap, such as code completion and document lookup.
 > Runs on 4×H100; quantization: WINT4

fastdeploy/spec_decode/mtp.py

Lines changed: 18 additions & 3 deletions
@@ -268,7 +268,11 @@ def _init_model_inputs(self):
         self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
         self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
         self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
-
+        self.model_inputs["input_ids_cpu"] = paddle.full(
+            shape=[self.max_num_seqs, self.parallel_config.max_model_len],
+            fill_value=-1,
+            dtype="int64",
+        ).cpu()
         self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
         self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
         self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
@@ -368,10 +372,17 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
-            self.input_ids_len[idx] = length
+            self.input_ids_len[idx] = length - 1

             if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                 length = len(request.prompt_token_ids)
+                if length > 1:
+                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
+                        "input_ids"
+                    ][idx : idx + 1, 1:length]
+                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
+                        request.prompt_token_ids
+                    )[1:]
                 self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                 prefill_token_num = self.max_draft_token_num + 1
                 self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
@@ -400,6 +411,10 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
                 self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][
                     idx : idx + 1, 1:length
                 ]
+                self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
+                    request.prompt_token_ids
+                )[1:]
+
                 self.model_inputs["pre_ids"][idx : idx + 1] = -1
                 self.model_inputs["step_idx"][idx : idx + 1] = 0
                 if self.cache_config.enable_chunked_prefill:
@@ -688,7 +703,7 @@ def _extend_draft_token_with_ngram_match(self):
         seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
         seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
         hybrid_mtp_ngram(
-            self.model_inputs["input_ids"]._copy_to(device, True),
+            self.model_inputs["input_ids_cpu"],
             self.input_ids_len,
             self.model_inputs["pre_ids"]._copy_to(device, True),
             self.model_inputs["step_idx"].cpu(),
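To make the change above easier to follow, here is a self-contained sketch of the n-gram top-up idea that _extend_draft_token_with_ngram_match feeds with the new CPU-side buffer. The function name, signature, and matching policy below are illustrative assumptions and do not mirror the real hybrid_mtp_ngram kernel; the diff's switch from _copy_to(device, True) to a persistent input_ids_cpu buffer suggests the intent is to avoid copying the full token tensor off the accelerator on every drafting step, so the sketch assumes matching runs on CPU over plain integer sequences.

```python
# Hypothetical sketch of hybrid MTP + n-gram draft extension.
# Names and matching policy are illustrative assumptions; they do not
# reproduce FastDeploy's hybrid_mtp_ngram kernel or its signature.
from typing import List


def extend_drafts_with_ngram(
    history: List[int],   # prompt + accepted tokens, e.g. one row of input_ids_cpu
    drafts: List[int],    # draft tokens already proposed by the MTP head
    max_drafts: int,      # total draft budget (num_speculative_tokens)
    ngram_size: int = 3,  # length of the suffix used as the match key
) -> List[int]:
    """Top up the MTP drafts by copying the tokens that followed the most
    recent earlier occurrence of the current suffix n-gram."""
    need = max_drafts - len(drafts)
    if need <= 0 or len(history) < ngram_size + 1:
        return drafts
    key = history[-ngram_size:]
    # Walk backwards so the most recent earlier match wins.
    for start in range(len(history) - ngram_size - 1, -1, -1):
        if history[start : start + ngram_size] == key:
            follow = history[start + ngram_size : start + ngram_size + need]
            return drafts + follow
    return drafts


# The suffix [5, 6, 7] also appears earlier in the history, so the tokens that
# followed it (8, 9, 2) are appended after the two MTP drafts (11, 12).
print(extend_drafts_with_ngram([1, 5, 6, 7, 8, 9, 2, 5, 6, 7], [11, 12], 5))
# -> [11, 12, 8, 9, 2]
```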
