Commit e607768

Speculation: Draft Target in new FW (NVIDIA#4558)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>

1 parent: cea5dd1
14 files changed (+193, -27 lines)

examples/pytorch/README.md (10 additions, 0 deletions)

@@ -81,3 +81,13 @@ python3 examples/pytorch/quickstart_advanced.py \
     --max_matching_ngram_size=2 \
     --spec_decode_nextn=4
 ```
+
+```bash
+# Draft Target
+python3 examples/pytorch/quickstart_advanced.py \
+    --model_dir meta-llama/Llama-3.1-8B-Instruct \
+    --spec_decode_algo draft_target \
+    --spec_decode_nextn 5 \
+    --draft_model_dir meta-llama/Llama-3.2-1B-Instruct \
+    --disable_overlap_scheduler
+```
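
The same setup can be driven from the LLM API directly instead of through the quickstart script. A minimal sketch, assuming the `DraftTargetDecodingConfig` fields introduced in this commit and the same model pair as above:

```python
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import DraftTargetDecodingConfig

# Pair a small draft model with the larger target; max_draft_len mirrors
# --spec_decode_nextn in the CLI example above.
spec_config = DraftTargetDecodingConfig(
    max_draft_len=5,
    pytorch_weights_path="meta-llama/Llama-3.2-1B-Instruct")

# Draft-target decoding does not support the overlap scheduler (see
# support_overlap_scheduler() in interface.py below), which is why the
# CLI example passes --disable_overlap_scheduler.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          speculative_config=spec_config)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32)))
```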

examples/pytorch/quickstart_advanced.py (12 additions, 5 deletions)

@@ -2,9 +2,9 @@
 
 from tensorrt_llm import SamplingParams
 from tensorrt_llm._torch import LLM
-from tensorrt_llm.llmapi import (EagleDecodingConfig, KvCacheConfig,
-                                 MTPDecodingConfig, NGramDecodingConfig,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (DraftTargetDecodingConfig, EagleDecodingConfig,
+                                 KvCacheConfig, MTPDecodingConfig,
+                                 NGramDecodingConfig, TorchCompileConfig)
 
 example_prompts = [
     "Hello, my name is",
@@ -109,7 +109,10 @@ def add_llm_args(parser):
     # Speculative decoding
     parser.add_argument('--spec_decode_algo', type=str, default=None)
     parser.add_argument('--spec_decode_nextn', type=int, default=1)
-    parser.add_argument('--eagle_model_dir', type=str, default=None)
+    parser.add_argument('--draft_model_dir',
+                        '--eagle_model_dir',
+                        type=str,
+                        default=None)
     parser.add_argument('--max_matching_ngram_size', type=int, default=5)
     parser.add_argument('--use_one_model', default=False, action='store_true')
 
@@ -166,8 +169,12 @@ def setup_llm(args):
     elif spec_decode_algo == "EAGLE3":
         spec_config = EagleDecodingConfig(
             max_draft_len=args.spec_decode_nextn,
-            pytorch_eagle_weights_path=args.eagle_model_dir,
+            pytorch_weights_path=args.draft_model_dir,
             eagle3_one_model=args.use_one_model)
+    elif spec_decode_algo == "DRAFT_TARGET":
+        spec_config = DraftTargetDecodingConfig(
+            max_draft_len=args.spec_decode_nextn,
+            pytorch_weights_path=args.draft_model_dir)
     elif spec_decode_algo == "NGRAM":
         spec_config = NGramDecodingConfig(
             prompt_lookup_num_tokens=args.spec_decode_nextn,
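
Note the argparse detail here: giving add_argument two option strings keeps the old --eagle_model_dir flag working as an alias, with the destination derived from the first option. A standalone sketch (the checkpoint path is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
# argparse derives dest from the first long option, so both flags
# write to args.draft_model_dir.
parser.add_argument('--draft_model_dir', '--eagle_model_dir',
                    type=str, default=None)

args = parser.parse_args(['--eagle_model_dir', '/ckpts/llama-1b-draft'])
assert args.draft_model_dir == '/ckpts/llama-1b-draft'
```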

tensorrt_llm/_torch/pyexecutor/sampler.py (3 additions, 2 deletions)

@@ -307,8 +307,9 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
         if request.state != LlmRequestState.GENERATION_COMPLETE:
             new_token = new_tokens_list[token_idx]
             num_tokens = request.add_new_token(new_token, beam_idx)
-            self._handle_stop_criteria(request, new_token, num_tokens,
-                                       beam_idx)
+            if self._handle_stop_criteria(request, new_token, num_tokens,
+                                          beam_idx):
+                continue
 
         # Accept draft tokens (if we have any) if and only if they match the new
         # token exactly.
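
The substance of this change: the boolean returned by _handle_stop_criteria was previously discarded, so the sampler fell through to draft-token acceptance even after a request hit a stop condition. A self-contained toy of the corrected flow, with all names hypothetical:

```python
def handle_stop(req, tok):
    # Stand-in for _handle_stop_criteria: True once the request finishes.
    return tok == req["stop_token"]

def sampling_step(requests):
    for req in requests:
        tok = req["new_token"]
        req["tokens"].append(tok)
        if handle_stop(req, tok):
            continue  # fixed behavior: no draft acceptance after a stop
        req["tokens"].extend(req["matching_drafts"])  # draft acceptance

reqs = [
    {"tokens": [], "new_token": 2, "stop_token": 2, "matching_drafts": [4]},
    {"tokens": [], "new_token": 7, "stop_token": 2, "matching_drafts": [8]},
]
sampling_step(reqs)
print([r["tokens"] for r in reqs])  # [[2], [7, 8]]: the stopped request ends
```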

tensorrt_llm/_torch/speculative/__init__.py (2 additions, 1 deletion)

@@ -1,3 +1,4 @@
+from .draft_target import DraftTargetConfig
 from .eagle3 import Eagle3Config, Eagle3SpecMetadata
 from .interface import SpecConfig, SpecMetadata
 from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -9,5 +10,5 @@
     "SpecConfig", "SpecMetadata", "MTPConfig", "MTPEagleWorker",
     "MTPSpecMetadata", "MTPWorker", "Eagle3Config", "Eagle3SpecMetadata",
     "get_spec_metadata", "get_spec_resource_manager", "get_spec_decoder",
-    "get_num_spec_layers", "get_spec_worker", "NGramConfig"
+    "get_num_spec_layers", "get_spec_worker", "NGramConfig", "DraftTargetConfig"
 ]

tensorrt_llm/_torch/speculative/draft_target.py (new file, 35 additions)

@@ -0,0 +1,35 @@
+from dataclasses import dataclass
+
+import torch
+
+from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
+
+
+@dataclass
+class DraftTargetConfig(SpecConfig):
+    spec_dec_name: str = "DRAFT_TARGET"
+
+    def __post_init__(self):
+        if self.draft_model_path is None:
+            raise ValueError("Path to Draft weights must be specified.")
+
+        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
+            self.spec_dec_name)
+        self.num_extra_kv_tokens = 0
+
+    def update_from_model_config(self, model_config):
+        pass
+
+    def get_draft_model_prompt(self,
+                               input_tokens: torch.Tensor) -> torch.Tensor:
+        return input_tokens
+
+
+@dataclass
+class DraftTargetSpecMetadata(SpecMetadata):
+
+    def __post_init__(self):
+        pass
+
+    def prepare(self):
+        pass
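
The new config is intentionally thin: draft_model_path and max_draft_tokens come from the SpecConfig base class, and __post_init__ only validates the draft path and resolves the decoding mode. A minimal construction sketch (hypothetical checkpoint path), mirroring how validate_speculative_config builds it in llm_args.py below:

```python
from tensorrt_llm._torch.speculative import DraftTargetConfig

cfg = DraftTargetConfig(
    max_draft_tokens=5,
    draft_model_path="/ckpts/llama-3.2-1b-instruct")

assert cfg.spec_dec_mode.is_draft_target()
assert cfg.num_extra_kv_tokens == 0
# Omitting draft_model_path would raise ValueError from __post_init__.
```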

tensorrt_llm/_torch/speculative/interface.py (7 additions, 3 deletions)

@@ -16,6 +16,7 @@ class SpeculativeDecodingMode(IntEnum):
     EAGLE3 = auto()
     EAGLE3_ONE_MODEL = auto()
     NGRAM = auto()
+    DRAFT_TARGET = auto()
     NONE = auto()
 
     def is_mtp(self):
@@ -39,6 +40,9 @@ def is_ngram(self):
     def is_none(self):
         return self == SpeculativeDecodingMode.NONE
 
+    def is_draft_target(self):
+        return self == SpeculativeDecodingMode.DRAFT_TARGET
+
     def without_logits(self):
         return self.is_mtp() or self.is_eagle3_one_model()
 
@@ -49,7 +53,7 @@ def support_overlap_scheduler(self):
         return self.is_mtp() or self.is_eagle3_one_model()
 
     def has_draft_model(self):
-        return self.is_eagle3()
+        return self.is_eagle3() or self.is_draft_target()
 
     def needs_kv_cache_recompute(self):
         """
@@ -77,8 +81,8 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
 
         # Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
-        return (self.is_eagle3()
-                and not (issubclass(attention_backend, TrtllmAttention)
+        return ((self.is_eagle3() or self.is_draft_target())
+                and not (isinstance(attention_backend, TrtllmAttention)
                 and get_sm_version() == 100)) or self.is_ngram()
 
     @staticmethod
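
Taken together, the new enum member and its predicates determine how the rest of the runtime treats draft-target. A quick sketch using only methods visible in this hunk; the last assertion is why the README example disables the overlap scheduler:

```python
from tensorrt_llm._torch.speculative.interface import SpeculativeDecodingMode

mode = SpeculativeDecodingMode.from_string("DRAFT_TARGET")
assert mode.is_draft_target()
assert mode.has_draft_model()                # now true alongside EAGLE3
assert not mode.support_overlap_scheduler()  # only MTP / EAGLE3 one-model
```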

tensorrt_llm/_torch/speculative/utils.py (6 additions, 0 deletions)

@@ -1,3 +1,4 @@
+from .draft_target import DraftTargetSpecMetadata
 from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata,
                      Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3Sampler,
                      Eagle3SpecMetadata)
@@ -36,6 +37,11 @@ def get_spec_metadata(spec_config,
             num_layers=spec_config.num_layers,
             hidden_size=spec_config.hidden_size,
             max_num_tokens=max_num_tokens)
+    elif spec_config.spec_dec_mode.is_draft_target():
+        return DraftTargetSpecMetadata(
+            max_draft_tokens=spec_config.max_draft_tokens,
+            spec_dec_mode=spec_config.spec_dec_mode,
+            max_num_requests=max_num_requests)
     else:
         return None

tensorrt_llm/llmapi/__init__.py (9 additions, 5 deletions)

@@ -3,13 +3,16 @@
 from ..sampling_params import GuidedDecodingParams, SamplingParams
 from .build_cache import BuildCacheConfig
 from .llm import LLM, RequestOutput
+# yapf: disable
 from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
                        CapacitySchedulerPolicy, ContextChunkingPolicy,
-                       DynamicBatchConfig, EagleDecodingConfig,
-                       ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
-                       LookaheadDecodingConfig, MedusaDecodingConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
-                       TorchCompileConfig, TorchLlmArgs, TrtLlmArgs)
+                       DraftTargetDecodingConfig, DynamicBatchConfig,
+                       EagleDecodingConfig, ExtendedRuntimePerfKnobConfig,
+                       KvCacheConfig, LlmArgs, LookaheadDecodingConfig,
+                       MedusaDecodingConfig, MTPDecodingConfig,
+                       NGramDecodingConfig, SchedulerConfig, TorchCompileConfig,
+                       TorchLlmArgs, TrtLlmArgs)
+# yapf: enable
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
                         QuantConfig)
 from .mpi_session import MpiCommSession
@@ -43,6 +46,7 @@
     'CacheTransceiverConfig',
     'NGramDecodingConfig',
     'TorchCompileConfig',
+    'DraftTargetDecodingConfig',
     'LlmArgs',
     'TorchLlmArgs',
     'TrtLlmArgs',

tensorrt_llm/llmapi/llm_args.py (28 additions, 6 deletions)

@@ -207,6 +207,7 @@ def from_dict(cls, data: dict):
             "Eagle": EagleDecodingConfig,
             "Lookahead": LookaheadDecodingConfig,
             "NGram": NGramDecodingConfig,
+            "DraftTarget": DraftTargetDecodingConfig,
         }
 
         config_class = config_classes.get(decoding_type)
@@ -238,7 +239,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
     dynamic_tree_max_topK: Optional[int] = None
     num_eagle_layers: Optional[int] = None
     max_non_leaves_per_layer: Optional[int] = None
-    pytorch_eagle_weights_path: Optional[str] = None
+    pytorch_weights_path: Optional[str] = None
     eagle3_one_model: Optional[bool] = True
 
     @classmethod
@@ -282,6 +283,16 @@ def from_dict(cls, data: dict):
     decoding_type: ClassVar[str] = "NGram"
 
 
+class DraftTargetDecodingConfig(DecodingBaseConfig):
+    pytorch_weights_path: Optional[str] = None
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "DraftTarget"
+
+
 class MTPDecodingConfig(DecodingBaseConfig):
     num_nextn_predict_layers: Optional[int] = 1
     use_relaxed_acceptance_for_thinking: Optional[bool] = False
@@ -896,10 +907,11 @@ class BaseLlmArgs(BaseModel):
         default=None, description="Cache transceiver config.")
 
     # Speculative decoding parameters
-    speculative_config: Optional[Union[
-        LookaheadDecodingConfig, MedusaDecodingConfig, EagleDecodingConfig,
-        MTPDecodingConfig, NGramDecodingConfig]] = Field(
-            default=None, description="Speculative decoding config.")
+    speculative_config: Optional[
+        Union[LookaheadDecodingConfig, MedusaDecodingConfig,
+              EagleDecodingConfig, MTPDecodingConfig, NGramDecodingConfig,
+              DraftTargetDecodingConfig]] = Field(
+                  default=None, description="Speculative decoding config.")
 
     batching_type: Optional[BatchingType] = Field(default=None,
                                                   description="Batching type.")
@@ -1302,7 +1314,7 @@ def validate_speculative_config(self):
             self.speculative_config = Eagle3Config(
                 max_draft_tokens=self.speculative_config.max_draft_len,
                 draft_model_path=self.speculative_config.
-                pytorch_eagle_weights_path,
+                pytorch_weights_path,
                 eagle3_one_model=self.speculative_config.
                 eagle3_one_model)
         elif isinstance(self.speculative_config, NGramDecodingConfig):
@@ -1320,6 +1332,16 @@ def validate_speculative_config(self):
                 is_use_oldest=self.speculative_config.is_use_oldest,
                 is_public_pool=self.speculative_config.is_public_pool,
             )
+        elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
+            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
+            assert self.backend == 'pytorch'
+            assert self.speculative_config.max_draft_len > 0
+            self.build_config.max_draft_len = self.speculative_config.max_draft_len
+            from tensorrt_llm._torch.speculative import DraftTargetConfig
+            self.speculative_config = DraftTargetConfig(
+                max_draft_tokens=self.speculative_config.max_draft_len,
+                draft_model_path=self.speculative_config.
+                pytorch_weights_path)
         elif isinstance(self.speculative_config, MTPDecodingConfig):
             from tensorrt_llm._torch.speculative import MTPConfig
             self.speculative_config = MTPConfig(
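
Since DraftTargetDecodingConfig registers the "DraftTarget" decoding type and its from_dict forwards straight to the constructor, the public config can also be built from a plain dict, for instance when decoded from a serialized args file. A small sketch, assuming max_draft_len is inherited from DecodingBaseConfig as the validation code above implies:

```python
from tensorrt_llm.llmapi import DraftTargetDecodingConfig

cfg = DraftTargetDecodingConfig.from_dict({
    "max_draft_len": 5,  # must be > 0; asserted in validate_speculative_config
    "pytorch_weights_path": "meta-llama/Llama-3.2-1B-Instruct",
})
# validate_speculative_config later converts this public config into the
# internal DraftTargetConfig and requires backend == 'pytorch'.
```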

tensorrt_llm/llmapi/llm_utils.py (3 additions, 1 deletion)

@@ -30,7 +30,8 @@
 from ..module import Module
 from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                           get_build_cache_config_from_env)
-from .llm_args import (CalibConfig, EagleDecodingConfig, KvCacheConfig, LlmArgs,
+from .llm_args import (CalibConfig, DraftTargetDecodingConfig,
+                       EagleDecodingConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig,
                        MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind,
                        _ModelWrapper, _ParallelConfig, get_model_format,
@@ -871,6 +872,7 @@ class LlmBuildStats:
     'MedusaDecodingConfig',
     'MTPDecodingConfig',
     'NGramDecodingConfig',
+    'DraftTargetDecodingConfig',
     'ContextChunkingPolicy',
     'CapacitySchedulerPolicy',
     'BuildConfig',
