
Commit 82d3587

[refactor] Unify name of NGram speculative decoding (#5937)

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>

1 parent 152e2df

File tree

15 files changed: +140 −143 lines changed

docs/source/advanced/speculative-decoding.md

Lines changed: 6 additions & 6 deletions

```diff
@@ -3,7 +3,7 @@
 - [About Speculative Sampling](#about-speculative-sampling)
 - [Performance Improvements](#Performance-improvements)
 - [Draft-Target-Model](#Draft-Target-Model)
-- [Prompt-Lookup-Decoding](#prompt-lookup-decoding)
+- [NGram](#ngram)
 - [Medusa](#medusa)
 - [Medusa Tree](#medusa-tree)
 - [Using Medusa with TensorRT-LLM](#using-medusa-with-tensorrt-llm)
@@ -36,7 +36,7 @@ TensorRT-LLM supports several approaches for generating draft tokens, including:
 1. [Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads paper](https://arxiv.org/abs/2401.10774).
 2. [Recurrent Drafter for Fast Speculative Decoding in Large Language Models](https://arxiv.org/html/2403.09919v1).
 3. [EAGLE: Speculative Sampling Requires Rethinking Feature Uncertainty](https://arxiv.org/pdf/2401.15077).
-3. Utilizing prompt tokens as draft tokens. For more information, refer to [Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding/).
+3. Utilizing prompt tokens as draft tokens. For more information, refer to [NGram](https://github.com/apoorvumang/prompt-lookup-decoding/).
 4. Utilizing Jacobi-like decoding to predict and verify draft tokens using the same model which does not need additional fine-tuning. Refer to [Break the Sequential Dependency of LLM Inference Using Lookahead Decoding](https://arxiv.org/pdf/2402.02057).
 
@@ -62,13 +62,13 @@ Subsequently, the prompt, now updated with the accepted tokens, is sent back to
 This iterative process continues until a predefined stop conditions are met.
 An example of this orchestration process can be found in the [TensorRT-LLM Triton backend](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/inflight_batcher_llm/client/e2e_grpc_speculative_decoding_client.py).
 
-We provide two styles of running Draft-Target-Model now: using TensorRT-LLM-BLS in Triton Inference Server, or using TensorRT-LLM directly. Detailed steps of running can be found in [examples/draft_target_model/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/draft_target_model/README.md) and the code can be found in [examples/prompt_lookup/run_dtm_pld.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/prompt_lookup/run_dtm_pld.py).
+We provide two styles of running Draft-Target-Model now: using TensorRT-LLM-BLS in Triton Inference Server, or using TensorRT-LLM directly. Detailed steps of running can be found in [examples/draft_target_model/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/draft_target_model/README.md) and the code can be found in [examples/ngram/run_dtm_ngram.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/ngram/run_dtm_ngram.py).
 
-## Prompt-Lookup-Decoding
+## NGram
 
-The Prompt-Lookup speculative decoding directly copies from the input prompt and previous generated output as draft tokens while generating the later output. It works like Draft-Target-Model but involves only one Target LLM model without further fine-tuning. The Prompt-Lookup profit from the scenarios which have high n-gram overlap between input prompt and output, such as summarization, document QA, multi-turn chat, code editing, etc.
+The NGram speculative decoding directly copies from the input prompt and previous generated output as draft tokens while generating the later output. It works like Draft-Target-Model but involves only one Target LLM model without further fine-tuning. The NGram profit from the scenarios which have high n-gram overlap between input prompt and output, such as summarization, document QA, multi-turn chat, code editing, etc.
 
-See document in [examples/prompt_lookup/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/prompt_lookup/README.md) and the code can be found in [examples/prompt_lookup/run_dtm_pld.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/prompt_lookup/run_dtm_pld.py).
+See document in [examples/ngram/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/ngram/README.md) and the code can be found in [examples/ngram/run_dtm_ngram.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/ngram/run_dtm_ngram.py).
 
 ## Medusa
```

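The NGram scheme the updated docs describe — drafting tokens by matching the tail of the current sequence against earlier text — can be sketched roughly as follows. This is a minimal stand-alone illustration, not the repository's implementation; the function name and defaults are assumptions for the sketch:

```python
def ngram_draft(tokens, max_matching_ngram_size=2, max_draft_len=4):
    """Propose draft tokens by matching the tail n-gram against earlier text.

    Tries the longest tail pattern first; on a hit, returns up to
    `max_draft_len` tokens that followed the earlier occurrence.
    """
    for size in range(max_matching_ngram_size, 0, -1):
        if len(tokens) <= size:
            continue
        pattern = tuple(tokens[-size:])
        # Scan earlier positions for the same pattern (excluding the tail itself).
        for start in range(len(tokens) - size):
            if tuple(tokens[start:start + size]) == pattern:
                continuation = tokens[start + size:start + size + max_draft_len]
                if continuation:
                    return continuation
    return []  # no match: fall back to normal one-token-per-step decoding


# The tail (1, 2) also occurred at the start, so its continuation is drafted.
ids = [1, 2, 3, 4, 1, 5, 6, 1, 2]
print(ngram_draft(ids))  # [3, 4, 1, 5]
```

When no n-gram of any size repeats, the function returns an empty draft and decoding proceeds one token per iteration, which is the fallback behavior the docs mention.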
examples/llm-api/README.md

Lines changed: 5 additions & 4 deletions

````diff
@@ -40,9 +40,10 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 python3 quickstart_advanced.py \
     --model_dir meta-llama/Llama-3.1-8B-Instruct \
     --spec_decode_algo NGRAM \
-    --max_matching_ngram_size=2 \
-    --spec_decode_nextn=4 \
-    --disable_overlap_scheduler
+    --spec_decode_nextn 4 \
+    --max_matching_ngram_size 2 \
+    --disable_overlap_scheduler \
+    --disable_kv_cache_reuse
 ```
 
 ```bash
@@ -52,6 +53,6 @@ python3 quickstart_advanced.py \
     --spec_decode_algo draft_target \
     --spec_decode_nextn 5 \
     --draft_model_dir meta-llama/Llama-3.2-1B-Instruct \
-    --disable_overlap_scheduler
+    --disable_overlap_scheduler \
     --disable_kv_cache_reuse
 ```
````
Lines changed: 30 additions & 28 deletions

````diff
@@ -1,17 +1,17 @@
-# Prompt-Lookup Speculative Decoding
+# NGram Speculative Decoding
 
-This document shows how to build and run a model using Prompt-Lookup speculative decoding (supported as `ASSISTED_GENERATION` in transformers and vLLM, source: [GitHub](https://github.com/apoorvumang/prompt-lookup-decoding/tree/main)) in TensorRT-LLM on single GPU, or single node multiple GPU.
+This document shows how to build and run a model using NGram speculative decoding (supported as `ASSISTED_GENERATION` in transformers and vLLM, source: [GitHub](https://github.com/apoorvumang/prompt-lookup-decoding/tree/main)) in TensorRT-LLM on single GPU, or single node multiple GPU.
 
 ## Overview
 
-We provide two styles of workflow to run Prompt-Lookup (named V1 and V2 respectively) now. V1 is in TRT workflow and similar to the Draft-Target-Model workflow, running in orchestrator mode and calling `runner.generate()` multiple times to get outputs, which is more flexible for customizing but slightly more overhead. V2 is in pytorch workflow and similar to the Look-Ahead workflow, running in leader mode and calling `runner.generate()` only one time to get outputs, which provides higher performance but fixed process.
+We provide two styles of workflow to run NGram (named V1 and V2 respectively) now. V1 is in TRT workflow and similar to the Draft-Target-Model workflow, running in orchestrator mode and calling `runner.generate()` multiple times to get outputs, which is more flexible for customizing but slightly more overhead. V2 is in pytorch workflow and similar to the Look-Ahead workflow, running in leader mode and calling `runner.generate()` only one time to get outputs, which provides higher performance but fixed process.
 
-The Prompt-Lookup has 3 additional hyperparameters that you need to specify to control the process of generation:
-- `prompt_lookup_num_tokens`: the maximum number of tokens provided as draft tokens in one iteration, which is usually from 4 to 10 in common usage (default value: 4). Empirically, the larger the value is, the higher acceptance rate but higher overhead is expected at the same time, so the right balance based on the models and application scenarios needs to be found.
+The NGram has 3 additional hyperparameters that you need to specify to control the process of generation:
+- `max_draft_len`: the maximum number of tokens provided as draft tokens in one iteration, which is usually from 4 to 10 in common usage (default value: 4). Empirically, the larger the value is, the higher acceptance rate but higher overhead is expected at the same time, so the right balance based on the models and application scenarios needs to be found.
 - `max_matching_ngram_size`: the maximum number of tokens extracted from the tail of the input prompt or generated output as a pattern, which is used to search corresponding draft tokens (default value: 2). Empirically, the larger the value is, the more precise context can be matched from the existed sequence, indicating higher acceptance rate, but the higher probability of miss-match and higher overhead appear, which fall back to normal generation (one token per iteration).
 - `device_list`: the index list of device(s) to run the model in V1 workflow. The length of it must be the same as the TP size of the draft model engine. For instances, `device_list=[0]` means using tp_size=1 and GPU 0 for the model, `device_list=[4,5,6,7]` means using tp=4 and GPU from 4 to 7 for the model. This parameter is neddless in V2 workflow.
 
-+ For example, the process of getting draft tokens using `prompt_lookup_num_tokens=2` and `max_matching_ngram_size=4` with a sentence `prefix=[..., t1, t2, t3, t4]` is like below:
++ For example, the process of getting draft tokens using `max_draft_len=2` and `max_matching_ngram_size=4` with a sentence `prefix=[..., t1, t2, t3, t4]` is like below:
 
 ```Python
 pattern = prefix[:-2] # pattern=[t3, t4] (length=2)
@@ -40,39 +40,39 @@ return None # No any candidate exists
 + We use an open-source `llama-v2-13B` models in this example.
 + `--use_paged_context_fmha=enable` must be specified since we need KVcache reuse in this approach.
 + `--speculative_decoding_mode=draft_tokens_external` must be specified.
-+ `--max_draft_len` must be specified larger or equal to `prompt_lookup_num_tokens`.
-+ `---prompt_lookup_config` is corresponding configuration of Prompt-Lookup, we can see its usage in [util.py](../util.py).
-+ As an example, `[10,2,[0]]` means `prompt_lookup_num_tokens=10`, `max_matching_ngram_size=2`, and device of target model is `GPU0`.
++ `--max_draft_len` must be specified as the length maximum of the draft tokens.
++ `--ngram_config` is corresponding configuration of NGram, we can see its usage in [util.py](../util.py).
++ As an example, `[10,2,[0]]` means `max_draft_len=10`, `max_matching_ngram_size=2`, and device of target model is `GPU0`.
 + `--kv_cache_enable_block_reuse` must be specified for this approach.
 + Only CPP session is supported, so `--use_py_session` must not be specified.
 + `--num_beams` can not be specified as larger than 1 since beam search is not supported in this approach yet.
 
 ```bash
 # Build engine
 python3 examples/models/core/llama/convert_checkpoint.py \
-    --model_dir=<Path To Llama-v2-13B repo> \
-    --output_dir=./ckpt-target \
-    --dtype=float16
+    --model_dir <Path To Llama-v2-13B repo> \
+    --output_dir ./ckpt-target \
+    --dtype float16
 
 trtllm-build \
-    --checkpoint_dir=./ckpt-target \
-    --output_dir=./target-engine \
-    --gemm_plugin=float16 \
-    --use_paged_context_fmha=enable \
-    --speculative_decoding_mode=draft_tokens_external \
-    --max_draft_len=10 \
-    --max_batch_size=4 \
-    --max_input_len=3200 \
-    --max_seq_len=4800
+    --checkpoint_dir ./ckpt-target \
+    --output_dir ./target-engine \
+    --gemm_plugin float16 \
+    --use_paged_context_fmha enable \
+    --speculative_decoding_mode draft_tokens_external \
+    --max_draft_len 10 \
+    --max_batch_size 4 \
+    --max_input_len 3200 \
+    --max_seq_len 4800
 
 # Run decoding
 python3 examples/run.py \
     --tokenizer_dir <Path To Llama-v2-7B repo> \
     --engine_dir ./target-engine \
-    --prompt_lookup_config="[10,2,[0]]" \
-    --max_output_len=256 \
+    --ngram_config "[10,2,[0]]" \
+    --max_output_len 256 \
     --kv_cache_enable_block_reuse \
-    --input_text="How does Draft-Sampling work?"
+    --input_text "How does Draft-Sampling work?"
 
 # Run summarization tasks
 python examples/summarize.py \
@@ -81,15 +81,17 @@ python examples/summarize.py \
     --check_accuracy \
     --hf_model_dir <Path To Llama-v2-7B repo> \
     --engine_dir ./target-engine \
-    --batch_size=1 \
-    --prompt_lookup_config="[10,2,[0]]" \
+    --batch_size 1 \
+    --ngram_config "[10,2,[0]]" \
     --kv_cache_enable_block_reuse
 ```
 
 ### V2 workflow
 
 ```bash
 python3 examples/llm-api/quickstart_advanced.py \
-    --max_matching_ngram_size=2 \
-    --spec_decode_nextn=4
+    --spec_decode_nextn 4 \
+    --max_matching_ngram_size 2 \
+    --disable_overlap_scheduler \
+    --disable_kv_cache_reuse
 ```
````
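The key/value pool this README describes — each n-gram key mapping to the tokens that followed it, with the most recent continuation preferred — can be sketched as follows. This is a simplified stand-alone illustration, not the `NgramPool` class itself; function names are assumptions for the sketch:

```python
def build_pool(sequence, max_matching_ngram_size=2, max_draft_len=4):
    """Map each n-gram (tuple key) to the tokens that followed it.

    Later occurrences overwrite earlier ones, so lookups prefer the most
    recent continuation, loosely mirroring MRU behavior.
    """
    pool = {}
    for size in range(1, max_matching_ngram_size + 1):
        for l in range(len(sequence) - size):
            r = min(l + size + max_draft_len, len(sequence))
            pool[tuple(sequence[l:l + size])] = sequence[l + size:r]
    return pool


def lookup_draft(pool, sequence, max_matching_ngram_size=2):
    """Return draft tokens for the longest matching tail pattern, else []."""
    for size in range(max_matching_ngram_size, 0, -1):
        value = pool.get(tuple(sequence[-size:]))
        if value:
            return value
    return []


seq = [10, 11, 12, 10, 11]
pool = build_pool(seq)
print(lookup_draft(pool, seq))  # [12, 10, 11]
```

In the sequence above the tail bigram `(10, 11)` already occurred at the start, so the tokens that followed it, `[12, 10, 11]`, are proposed as the draft for the next iteration.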
Lines changed: 34 additions & 36 deletions

```diff
@@ -23,20 +23,20 @@
 from tensorrt_llm.runtime import ModelRunnerCpp
 
 
-class PLDPool:  # Ngrams pool for Prompt-Lookup-Decoding
+class NgramPool:  # Ngrams pool for Ngram
 
     def __init__(
         self,
         input_batch_size: int,
-        prompt_lookup_num_tokens: int,
+        max_draft_len: int,
         max_matching_ngram_size: int,
         end_id: int,
         max_seq_len: list[int],
         is_keep_all: bool = True,
         is_use_oldest: bool = True,
     ):
         self.input_batch_size = input_batch_size
-        self.prompt_lookup_num_tokens = prompt_lookup_num_tokens
+        self.max_draft_len = max_draft_len
         self.max_matching_ngram_size = max_matching_ngram_size
         self.end_id = end_id
         self.max_seq_len = max_seq_len
@@ -45,7 +45,7 @@ def __init__(
         self.pool = [{} for _ in range(input_batch_size)]
         self.start_index = [0 for _ in range(input_batch_size)]
 
-        assert self.prompt_lookup_num_tokens > 0, f"prompt_lookup_num_tokens must be greater than 0, but got {self.prompt_lookup_num_tokens}"
+        assert self.max_draft_len > 0, f"max_draft_len must be greater than 0, but got {self.max_draft_len}"
         assert self.max_matching_ngram_size > 0, f"max_matching_ngram_size must be greater than 0, but got {self.max_matching_ngram_size}"
 
     def print_pool(self):
@@ -82,16 +82,15 @@ def get_draft_tokens(self, prefix: list[torch.Tensor],
                              -1):
             # Find each possible key-value combination, and use tuple for hash
             for l in range(len(sequence) - size):
-                r = min(l + size + self.prompt_lookup_num_tokens,
-                        len(sequence))
+                r = min(l + size + self.max_draft_len, len(sequence))
                 key = tuple(sequence[l:l + size])
                 value = tuple(sequence[l + size:r])
                 if key not in self.pool[gbi] or not self.is_keep_all or \
-                    len(self.pool[gbi][key][0]) < self.prompt_lookup_num_tokens:
+                    len(self.pool[gbi][key][0]) < self.max_draft_len:
                     # Update the value if
                     # 1. the key does not exist
                     # 2. we only keep the newest one value for each key (MRU)
-                    # 3. the length of the value saved before is less than `prompt_lookup_num_tokens`
+                    # 3. the length of the value saved before is less than `max_draft_len`
                     self.pool[gbi][key] = OrderedSet((value, ))
                 elif value not in self.pool[gbi][key]:
                     # Extend the value if the key is already existed but count of values is not enough
@@ -113,26 +112,26 @@ def get_draft_tokens(self, prefix: list[torch.Tensor],
                     break
             draft_tokens.append(chosen_ids)
             self.start_index[gbi] = max(
-                0, prefix_len[bi] - (self.prompt_lookup_num_tokens +
-                                     self.max_matching_ngram_size - 1))
+                0, prefix_len[bi] -
+                (self.max_draft_len + self.max_matching_ngram_size - 1))
 
         return draft_tokens, None
 
 
-def run_dtm_pld(batch_input_ids,
-                args,
-                runtime_rank,
-                end_id,
-                pad_id,
-                stop_words_list,
-                bad_words_list,
-                vocab_size,
-                *,
-                target_runner=None):
-    # `dtm` for Draft-Target-Model, `pld` for Prompt-Lookup-Decoding
+def run_dtm_ngram(batch_input_ids,
+                  args,
+                  runtime_rank,
+                  end_id,
+                  pad_id,
+                  stop_words_list,
+                  bad_words_list,
+                  vocab_size,
+                  *,
+                  target_runner=None):
+    # `dtm` for Draft-Target-Model, `ngram` for NGram
     is_dtm = (args.draft_target_model_config is not None)
-    is_pld = (args.prompt_lookup_config is not None)
-    assert is_dtm ^ is_pld, "`--draft_target_model_config` and `--prompt_lookup_config` can not be specified at the same time."
+    is_ngram = (args.ngram_config is not None)
+    assert is_dtm ^ is_ngram, "`--draft_target_model_config` and `--ngram_config` can not be specified at the same time."
     if is_dtm:
         assert args.draft_engine_dir is not None, "`--draft_engine_dir` must be specified in Draft-Target-Model."
         draft_len, draft_device_list, target_device_list, use_logits = ast.literal_eval(
@@ -142,12 +141,11 @@ def run_dtm_pld(batch_input_ids,
         logger.info(f"Device(s) for draft model: {draft_device_list}")
         logger.info(f"Device(s) for target model: {target_device_list}")
         logger.info(f"Use logits to accept tokens: {use_logits}")
-    if is_pld:
-        logger.info(
-            f"Using Prompt-Lookup-Decoding speculative decoding V1 workflow")
-        prompt_lookup_num_tokens, max_matching_ngram_size, target_device_list = ast.literal_eval(
-            args.prompt_lookup_config)
-        logger.info(f"prompt_lookup_num_tokens: {prompt_lookup_num_tokens}")
+    if is_ngram:
+        logger.info(f"Using NGram speculative decoding V1 workflow")
+        max_draft_len, max_matching_ngram_size, target_device_list = ast.literal_eval(
+            args.ngram_config)
+        logger.info(f"max_draft_len: {max_draft_len}")
         logger.info(f"max_matching_ngram_size: {max_matching_ngram_size}")
         logger.info(f"Device(s) for the model: {target_device_list}")
         use_logits = False  # `logits` is useless in this approach yet
@@ -166,9 +164,9 @@ def run_dtm_pld(batch_input_ids,
         n_draft_token = [0 for _ in range(input_batch_size)]
         n_accept_token = [0 for _ in range(input_batch_size)]
 
-    if is_pld:
-        pld_pool = PLDPool(input_batch_size, prompt_lookup_num_tokens,
-                           max_matching_ngram_size, end_id, max_seq_len)
+    if is_ngram:
+        ngram_pool = NgramPool(input_batch_size, max_draft_len,
+                               max_matching_ngram_size, end_id, max_seq_len)
 
     # Repack the output like the output of function `generate`
     outputs = {}
@@ -297,8 +295,8 @@ def run_dtm_pld(batch_input_ids,
             if use_logits:
                 d_logits[bi] = draft["generation_logits"][bi, 0,
                                                           -d_len[bi]:, :]
-        if is_pld:
-            d_ids, d_logits = pld_pool.get_draft_tokens(prefix, batch_slot)
+        if is_ngram:
+            d_ids, d_logits = ngram_pool.get_draft_tokens(prefix, batch_slot)
         d_len = [len(i) for i in d_ids]
 
         # Run target model
@@ -310,8 +308,8 @@ def run_dtm_pld(batch_input_ids,
                     draft_logits_list=d_logits)
         if is_dtm:
             max_new_tokens = draft_len + 1
-        if is_pld:
-            max_new_tokens = prompt_lookup_num_tokens + 1
+        if is_ngram:
+            max_new_tokens = max_draft_len + 1
         target_generation_kwargs.update(max_new_tokens=max_new_tokens)
         target = target_runner.generate(**target_generation_kwargs)
         torch.cuda.synchronize()
```

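The V1 orchestration in the script above follows the standard draft-then-verify loop: the target model generates `max_draft_len + 1` tokens over the drafted positions, and the longest agreeing prefix is accepted plus one target token. A sketch of that acceptance rule (illustrative only, not the script's exact code):

```python
def accept_draft(draft_tokens, target_tokens):
    """Greedy acceptance: keep the longest prefix of the draft that the
    target model reproduced, then append the target's own next token.

    `target_tokens` holds the target model's output for the same positions
    plus one extra token (the correction/bonus token).
    """
    n = 0
    while n < len(draft_tokens) and draft_tokens[n] == target_tokens[n]:
        n += 1
    return target_tokens[:n + 1]


# Target agrees on [5, 6], rejects 7, and supplies 9 as the correction.
print(accept_draft([5, 6, 7], [5, 6, 9, 8]))  # [5, 6, 9]
```

Each iteration therefore advances by at least one token (the correction) and at most `max_draft_len + 1` tokens, which is where the speedup comes from when the n-gram drafts match well.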