Skip to content

Commit c18992c

Browse files
Defaulting of EAGLE3 num layers and insertion into transforms now happen in llm_args.py
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 4e59e4d commit c18992c

File tree

3 files changed: +41 −49 lines changed

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
from importlib.resources import files
22
from pathlib import Path
3-
from typing import Any, Dict, List, Literal, Optional, Type, Union
3+
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
44

55
import torch
66
from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator
77
from pydantic_settings import BaseSettings, SettingsConfigDict
88

99
from tensorrt_llm.models.modeling_utils import QuantConfig
1010

11-
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
11+
from ...llmapi.llm_args import (
12+
BaseLlmArgs,
13+
BuildConfig,
14+
EagleDecodingConfig,
15+
KvCacheConfig,
16+
SamplerType,
17+
_ParallelConfig,
18+
)
1219
from .models import ModelFactory, ModelFactoryRegistry
1320
from .utils._config import DynamicYamlMixInForSettings
1421
from .utils.logger import ad_logger
@@ -38,6 +45,12 @@ def _check_for_default_value_only(
3845
return value
3946

4047

48+
def default_eagle3_layers_to_capture(num_hidden_layers: int) -> Set[int]:
49+
if num_hidden_layers <= 5:
50+
raise ValueError("Not enough hidden layers for default EAGLE3 capture")
51+
return {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}
52+
53+
4154
_TRANSFORMS_SHORTCUT_LOOKUP = {
4255
"attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"),
4356
"free_mem_ratio": ("resize_kv_cache.free_mem_ratio",),
@@ -150,6 +163,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
150163

151164
enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.")
152165

166+
draft_checkpoint_loader: Optional[object] = Field(
167+
default=None,
168+
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
169+
)
170+
153171
### INFERENCE OPTIMIZER CONFIG #################################################################
154172
mode: Literal["graph", "transformers"] = Field(
155173
default="graph",
@@ -190,11 +208,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
190208
),
191209
)
192210

193-
draft_checkpoint_loader: Optional[object] = Field(
194-
default=None,
195-
description="The checkpoint loader to use for the draft model when using speculative decoding with two models.",
196-
)
197-
198211
### SEQUENCE INTERFACE CONFIG ##################################################################
199212
max_input_len: int = Field(default=1024, description="The maximum input length.")
200213
max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.")
@@ -401,6 +414,26 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A
401414
msg = "AutoDeploy only supports parallelization via the `world_size` argument."
402415
return _check_for_default_value_only(cls, value, info, msg)
403416

417+
@model_validator(mode="after")
418+
def default_eagle3_layers_to_capture(self):
419+
if self.speculative_config is None or not isinstance(
420+
self.speculative_config, EagleDecodingConfig
421+
):
422+
return self
423+
424+
if self.speculative_config.eagle3_layers_to_capture is None:
425+
num_hidden_layers = self.create_factory()._get_model_config()[0].num_hidden_layers
426+
self.speculative_config.eagle3_layers_to_capture = default_eagle3_layers_to_capture(
427+
num_hidden_layers
428+
)
429+
430+
# insert the layers to capture into the transforms config.
431+
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
432+
self.speculative_config.eagle3_layers_to_capture
433+
)
434+
435+
return self
436+
404437
@model_validator(mode="after")
405438
def validate_parallel_config(self):
406439
"""Setup parallel config according to world_size.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -263,33 +263,6 @@ def construct_draft_llm_args(
263263
return draft_llm_args
264264

265265

266-
def _populate_hidden_state_capture_config(ad_config: LlmArgs) -> LlmArgs:
267-
"""Ensure hidden-state capture transforms follow Eagle config.
268-
269-
When users supply an EagleDecodingConfig, automatically wire the
270-
DetectHiddenStatesForCapture transform so it captures the layers
271-
requested by the Eagle config, and make sure the corresponding
272-
cached residual add insertion is enabled.
273-
"""
274-
275-
spec_config = ad_config.speculative_config
276-
if not isinstance(spec_config, EagleDecodingConfig):
277-
return ad_config
278-
279-
spec_dec_mode = spec_config.spec_dec_mode
280-
if not (spec_dec_mode.is_eagle3() or spec_dec_mode.is_eagle3_one_model()):
281-
return ad_config
282-
283-
# DetectHiddenStatesForCapture configuration
284-
capture_key = "detect_hidden_states_for_capture"
285-
capture_cfg = ad_config.transforms.get(capture_key, {"stage": "pattern_matcher"})
286-
if spec_config.eagle3_layers_to_capture is not None:
287-
capture_cfg["eagle3_layers_to_capture"] = spec_config.eagle3_layers_to_capture
288-
ad_config.transforms[capture_key] = capture_cfg
289-
290-
return ad_config
291-
292-
293266
def create_draft_kv_cache_manager_maybe(
294267
draft_model_engine: Optional[PyTorchModelEngine],
295268
ad_config: LlmArgs,
@@ -348,9 +321,6 @@ def build_from_config(cls, ad_config: LlmArgs):
348321

349322
factory = ad_config.create_factory()
350323

351-
# Auto-configure hidden-state capture when using Eagle speculative decoding.
352-
ad_config = _populate_hidden_state_capture_config(ad_config)
353-
354324
# initialize seq info object
355325
seq_info = SequenceInfo(
356326
max_seq_len=max_seq_len,

tests/integration/defs/examples/test_ad_speculative_decoding.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import os
17-
from typing import Set
1817

1918
import pytest
2019
from build_and_run_ad import ExperimentConfig, main
@@ -52,27 +51,17 @@ def get_model_paths():
5251
return base_model, draft_target_model, eagle_model
5352

5453

55-
def get_default_layers_to_capture(num_layers: int) -> Set[int]:
56-
if num_layers <= 5:
57-
raise ValueError("Not enough hidden layers for default EAGLE3 capture")
58-
return {1, num_layers // 2 - 1, num_layers - 4}
59-
60-
6154
def make_spec_config(mode: str, spec_model_path: str):
6255
if mode == "draft_target":
6356
return DraftTargetDecodingConfig(
6457
max_draft_len=DRAFT_TARGET_MAX_DRAFT_LEN, speculative_model_dir=spec_model_path
6558
)
6659
if mode == "eagle":
67-
num_layers = 32 # hardcoded num layers for Llama-3.1-8B-Instruct
68-
69-
# Note: Valid layer indices must be provided in EagleDecodingConfig when using AutoDeploy.
70-
# TODO: Should we have some defaulting behavior in the Engine, similar to PyTorch backend?
7160
return EagleDecodingConfig(
7261
max_draft_len=EAGLE_MAX_DRAFT_LEN,
7362
speculative_model_dir=spec_model_path,
7463
eagle3_one_model=False,
75-
eagle3_layers_to_capture=get_default_layers_to_capture(num_layers),
64+
eagle3_layers_to_capture=None,
7665
)
7766
raise ValueError(f"Unknown speculative mode: {mode}")
7867

0 commit comments

Comments (0)