
Commit 8ecdeee

[refactor] Simplification of Speculative decoding configs - Part 2 (NVIDIA#5936)
Signed-off-by: wili-65535 <[email protected]>
Co-authored-by: wili-65535 <[email protected]>
1 parent bc2fb29 commit 8ecdeee

File tree

9 files changed: +60 -37 lines changed


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..model_config import ModelConfig
-from ..speculative import get_spec_decoder
+from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
 from .config import PyTorchConfig
 from .config_utils import is_mla, is_nemotron_hybrid
 from .guided_decoder import GuidedDecoder
@@ -164,7 +164,7 @@ def _get_token_num_for_estimation(self) -> int:
 
         if spec_cfg is not None:
             num_extra_tokens_per_seq += spec_cfg.max_draft_len
-            num_extra_tokens_per_seq += spec_cfg.num_extra_kv_tokens
+            num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
         for req in self._dummy_reqs:
             num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
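For intuition, the estimate above folds the speculative budget into every dummy request before rounding up to whole KV-cache blocks. A minimal stand-alone sketch of that arithmetic (the helper name, the tokens_per_block value, and the example numbers are illustrative assumptions, not code from this change):

import math

def estimate_request_tokens(prompt_len: int, max_draft_len: int,
                            num_extra_kv_tokens: int, tokens_per_block: int) -> int:
    # Per-sequence extras: draft tokens plus any extra KV tokens the decoding mode needs.
    num_extra_tokens_per_seq = max_draft_len + num_extra_kv_tokens
    num_req_tokens = prompt_len + num_extra_tokens_per_seq
    # Requests cannot share KV cache blocks, so round up to a whole block count.
    return math.ceil(num_req_tokens / tokens_per_block) * tokens_per_block

# e.g. a 100-token prompt, max_draft_len = 3, 2 extra KV tokens, 32 tokens per block:
print(estimate_request_tokens(100, 3, 2, 32))  # -> 128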

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 3 deletions
@@ -18,6 +18,8 @@
 from tensorrt_llm._torch.models.checkpoints.base_checkpoint_loader import \
     BaseCheckpointLoader
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
+from tensorrt_llm._torch.speculative import (
+    get_num_extra_kv_tokens, update_spec_config_from_model_config)
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  torch_dtype_to_str, trace_func)
@@ -353,7 +355,8 @@ def __init__(
 
         if self.is_spec_decode:
             self.spec_metadata = None
-            self.spec_config.update_from_model_config(self.model.config)
+            update_spec_config_from_model_config(self.spec_config,
+                                                 self.model.config)
             max_num_draft_tokens = self.spec_config.max_draft_len * batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
@@ -1442,8 +1445,7 @@ def previous_seq_slots_device():
         attn_metadata.kv_cache_params = KVCacheParams(
             use_cache=True,
             num_cached_tokens_per_seq=num_cached_tokens_per_seq,
-            num_extra_kv_tokens=0 if self.spec_config is None else
-            self.spec_config.num_extra_kv_tokens)
+            num_extra_kv_tokens=get_num_extra_kv_tokens(self.spec_config))
         attn_metadata.kv_cache_manager = kv_cache_manager
 
         attn_metadata.prepare()

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 3 additions & 2 deletions
@@ -19,7 +19,8 @@
 
 from ..attention_backend.interface import AttentionRuntimeFeatures
 from ..distributed import MPIDist
-from ..speculative import get_spec_drafter, get_spec_resource_manager
+from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
+                           get_spec_resource_manager)
 from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
                     create_py_executor_instance, instantiate_sampler, is_mla)
 from .config import PyTorchConfig
@@ -266,7 +267,7 @@ def create_py_executor(
            max_seq_len += spec_config.max_draft_len
 
    if spec_config is not None:
-        max_seq_len += spec_config.num_extra_kv_tokens
+        max_seq_len += get_num_extra_kv_tokens(spec_config)
        max_seq_len += spec_config.max_draft_len
 
    executor_config.max_seq_len = max_seq_len
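The same helper now also feeds the executor's sequence-length budget. As a rough worked example tracing only the "if spec_config is not None" block above, under assumed values (an EAGLE3 one-model config with max_draft_len = 3, so the helper returns 2, and a base max_seq_len of 4096):

max_seq_len = 4096
max_draft_len = 3
num_extra_kv_tokens = max_draft_len - 1   # what get_num_extra_kv_tokens returns for this mode
max_seq_len += num_extra_kv_tokens + max_draft_len
print(max_seq_len)                        # -> 4101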

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 3 additions & 1 deletion
@@ -176,7 +176,9 @@ def __init__(
         self.kv_factor = 1 if kv_cache_type == CacheTypeCpp.SELFKONLY else 2
         # Some speculative decoding methods need to use different kv lengths for the
         # draft/target layers. Add extra tokens to handle this issue.
-        self.num_extra_kv_tokens = 0 if spec_config is None else spec_config.num_extra_kv_tokens
+        # Import here to avoid circular imports
+        from ..speculative import get_num_extra_kv_tokens
+        self.num_extra_kv_tokens = get_num_extra_kv_tokens(spec_config)
         self.event_buffer_max_size = kv_cache_config.event_buffer_max_size
         self.max_num_tokens = max_num_tokens
 
tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 6 additions & 3 deletions
@@ -2,9 +2,10 @@
 from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
-from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_drafter,
-                    get_spec_metadata, get_spec_resource_manager,
-                    get_spec_worker)
+from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
+                    get_spec_decoder, get_spec_drafter, get_spec_metadata,
+                    get_spec_resource_manager, get_spec_worker,
+                    update_spec_config_from_model_config)
 
 __all__ = [
     "Eagle3SpecMetadata",
@@ -14,10 +15,12 @@
     "NGramDrafter",
     "NGramPoolManager",
     "SpecMetadata",
+    "get_num_extra_kv_tokens",
     "get_num_spec_layers",
     "get_spec_decoder",
     "get_spec_drafter",
     "get_spec_metadata",
     "get_spec_resource_manager",
     "get_spec_worker",
+    "update_spec_config_from_model_config",
 ]

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 18 additions & 2 deletions
@@ -3,6 +3,8 @@
 import traceback
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
+import torch
+
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.logger import logger
 
@@ -15,6 +17,20 @@
 
 if TYPE_CHECKING:
     from ..pyexecutor.model_engine import ModelEngine
+from .interface import SpeculativeDecodingMode
+
+
+# Place the tool function here to avoid circular import
+def get_draft_model_prompt(spec_dec_mode: SpeculativeDecodingMode,
+                           input_tokens: torch.Tensor) -> torch.Tensor:
+    """
+    Can be used to modify prompts for speculative algorithms that need to update tokens
+    before drafting.
+    """
+    if spec_dec_mode.is_eagle3():
+        # EAGLE3 always throws away the first token when processing draft inputs
+        return input_tokens[1:]
+    return input_tokens
 
 
 class ModelDrafter(Drafter):
@@ -113,8 +129,8 @@ def _create_draft_request_for_request(
         """Create a draft request based on the original request state."""
         num_draft_tokens, num_accepted_tokens = self._initialize_draft_tokens(
             request)
-        input_tokens = self.spec_config.get_draft_model_prompt(
-            request.get_tokens()[0])
+        input_tokens = get_draft_model_prompt(self.spec_config.spec_dec_mode,
+                                              request.get_tokens()[0])
 
         # First time seeing this request - context request
         if request.max_beam_num_tokens - 1 == request.py_prompt_len:
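Since the EAGLE3-specific prompt handling now lives in a module-level helper instead of on the config object, its contract can be exercised in isolation. A self-contained sketch (the stub mode classes below are duck-typed stand-ins for SpeculativeDecodingMode, not the real enum; the token ids are made up; the import path follows the diff above):

import torch

from tensorrt_llm._torch.speculative.model_drafter import get_draft_model_prompt

class _Eagle3Stub:
    # Stand-in exposing only the method the helper consults.
    def is_eagle3(self) -> bool:
        return True

class _OtherStub:
    def is_eagle3(self) -> bool:
        return False

tokens = torch.tensor([1, 15043, 3186, 29991])  # made-up token ids
# EAGLE3 drops the first token before drafting; other modes pass the prompt through.
assert get_draft_model_prompt(_Eagle3Stub(), tokens).tolist() == [15043, 3186, 29991]
assert get_draft_model_prompt(_OtherStub(), tokens).tolist() == [1, 15043, 3186, 29991]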

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 21 additions & 0 deletions
@@ -153,3 +153,24 @@ def get_spec_worker(spec_config, mapping):
     if spec_config.spec_dec_mode.is_eagle3_one_model():
         return Eagle3OneModelWorker(spec_config, mapping)
     return None
+
+
+def get_num_extra_kv_tokens(spec_config):
+    """
+    Implementation detail for one model implementations of speculative decoding. Extra
+    KV cache tokens are required.
+    """
+    if spec_config is None:
+        return 0
+    if spec_config.spec_dec_mode.is_eagle3_one_model(
+    ) or spec_config.spec_dec_mode.is_mtp_eagle():
+        return spec_config.max_draft_len - 1
+    return 0
+
+
+def update_spec_config_from_model_config(spec_config, model_config):
+    if spec_config.spec_dec_mode.is_mtp():
+        # Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.
+        spec_config.max_draft_len = spec_config.num_nextn_predict_layers
+        # Use `num_nextn_predict_layers_from_model_config` to decide decoding mode MTP / MTP_EAGLE.
+        spec_config.num_nextn_predict_layers_from_model_config = model_config.num_nextn_predict_layers
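Call sites now derive the extra-KV-token count from the decoding mode instead of reading a stored config field. A small usage sketch (the SimpleNamespace objects below are duck-typed stand-ins, not real DecodingBaseConfig instances):

from types import SimpleNamespace

from tensorrt_llm._torch.speculative import get_num_extra_kv_tokens

one_model_mode = SimpleNamespace(is_eagle3_one_model=lambda: True,
                                 is_mtp_eagle=lambda: False)
two_model_mode = SimpleNamespace(is_eagle3_one_model=lambda: False,
                                 is_mtp_eagle=lambda: False)

# No speculation configured: no extra KV tokens are reserved.
assert get_num_extra_kv_tokens(None) == 0
# One-model modes (EAGLE3 one-model, MTP-Eagle) reserve max_draft_len - 1 tokens.
assert get_num_extra_kv_tokens(
    SimpleNamespace(spec_dec_mode=one_model_mode, max_draft_len=3)) == 2
# Other modes reserve none.
assert get_num_extra_kv_tokens(
    SimpleNamespace(spec_dec_mode=two_model_mode, max_draft_len=3)) == 0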

tensorrt_llm/llmapi/llm_args.py

Lines changed: 1 addition & 22 deletions
@@ -248,7 +248,6 @@ class _ModelFormatKind(Enum):
 class DecodingBaseConfig(BaseModel):
     max_draft_len: Optional[int] = None
     speculative_model_dir: Optional[Union[str, Path]] = None
-    num_extra_kv_tokens: int = 0
 
     @classmethod
     def from_dict(cls, data: dict):
@@ -295,13 +294,6 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.from_string(
             self.decoding_type.upper())
 
-    def update_from_model_config(self, model_config):
-        pass
-
-    def get_draft_model_prompt(self,
-                               input_tokens: torch.Tensor) -> torch.Tensor:
-        return input_tokens
-
 
 class MedusaDecodingConfig(DecodingBaseConfig):
     medusa_choices: Optional[List[List[int]]] = None
@@ -345,13 +337,6 @@ def spec_dec_mode(self):
             return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
         return TorchSpeculativeDecodingMode.EAGLE3
 
-    def get_draft_model_prompt(self,
-                               input_tokens: torch.Tensor) -> torch.Tensor:
-        """
-        Eagle3 always throws away the first token when processing draft inputs
-        """
-        return input_tokens[1:]
-
 
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
@@ -448,11 +433,6 @@ def spec_dec_mode(self):
             return TorchSpeculativeDecodingMode.MTP_EAGLE
         return TorchSpeculativeDecodingMode.MTP
 
-    def update_from_model_config(self, model_config):
-        assert self.num_nextn_predict_layers > 0
-        if model_config.num_nextn_predict_layers == 1 and not self.use_mtp_vanilla:
-            self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
-
 
 class PybindMirror(ABC):
     ''' A class containing the utilities for mirroring Python classes to
@@ -1468,8 +1448,6 @@ def validate_speculative_config(self):
             assert self.speculative_config.speculative_model_dir is not None, "Path to EAGLE3 weights must be specified."
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE
-            if self.speculative_config.eagle3_one_model:
-                self.speculative_config.num_extra_kv_tokens = self.speculative_config.max_draft_len - 1
             if self.backend not in ['pytorch', '_autodeploy']:
                 eagle_config = _EagleConfig(
                     self.speculative_config.eagle_choices,
@@ -1490,6 +1468,7 @@ def validate_speculative_config(self):
         elif isinstance(self.speculative_config, DraftTargetDecodingConfig):
             assert self.backend in ['pytorch']
             assert self.speculative_config.max_draft_len > 0
+            assert self.speculative_config.speculative_model_dir is not None, "Path to draft model must be specified."
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
             self.build_config.max_draft_len = self.speculative_config.max_draft_len
 
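On the LLM API side, the net effect is that num_extra_kv_tokens disappears from the user-facing configs and is recomputed where needed. A rough sketch of the resulting flow, illustrative only (the keyword values and the path are placeholders; only fields visible in this diff are used, and the construction is assumed to validate):

from tensorrt_llm._torch.speculative import get_num_extra_kv_tokens
from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="/path/to/eagle3/draft",  # placeholder path
    eagle3_one_model=True)

# The config no longer carries num_extra_kv_tokens; the count is derived on demand
# from the decoding mode (max_draft_len - 1 = 2 for the one-model mode assumed here).
print(get_num_extra_kv_tokens(spec_config))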

tests/unittest/_torch/speculative/test_draft_target.py

Lines changed: 1 addition & 2 deletions
@@ -49,8 +49,7 @@ def test_llama_draft_target(use_cuda_graph: bool, attn_backend: str):
     )
 
     prompts = [
-        #"The capital of France is", # Waive this prompt to avoid a flaky error, https://nvbugspro.nvidia.com/bug/5374319
-        "The capital of Germany is",
+        "The capital of France is",
         "The president of the United States is",
     ]
     sampling_params = SamplingParams(max_tokens=32)
