|
1 | 1 | from importlib.resources import files |
2 | 2 | from pathlib import Path |
3 | | -from typing import Any, Dict, List, Literal, Optional, Type, Union |
| 3 | +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union |
4 | 4 |
|
5 | 5 | import torch |
6 | 6 | from pydantic import Field, PrivateAttr, ValidationInfo, field_validator, model_validator |
7 | 7 | from pydantic_settings import BaseSettings, SettingsConfigDict |
8 | 8 |
|
9 | 9 | from tensorrt_llm.models.modeling_utils import QuantConfig |
10 | 10 |
|
11 | | -from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig |
| 11 | +from ...llmapi.llm_args import ( |
| 12 | + BaseLlmArgs, |
| 13 | + BuildConfig, |
| 14 | + EagleDecodingConfig, |
| 15 | + KvCacheConfig, |
| 16 | + SamplerType, |
| 17 | + _ParallelConfig, |
| 18 | +) |
12 | 19 | from .models import ModelFactory, ModelFactoryRegistry |
13 | 20 | from .utils._config import DynamicYamlMixInForSettings |
14 | 21 | from .utils.logger import ad_logger |
@@ -38,6 +45,12 @@ def _check_for_default_value_only( |
38 | 45 | return value |
39 | 46 |
|
40 | 47 |
|
def default_eagle3_layers_to_capture(num_hidden_layers: int) -> Set[int]:
    """Return the default hidden-layer indices to capture for EAGLE3.

    Three values are computed: ``1``, ``num_hidden_layers // 2 - 1`` and
    ``num_hidden_layers - 4``.

    Args:
        num_hidden_layers: Number of hidden layers to pick the defaults from.

    Returns:
        The computed values as a set. Equal values collapse, so for
        ``num_hidden_layers == 6`` the result is ``{1, 2}`` (two elements).

    Raises:
        ValueError: If ``num_hidden_layers`` is 5 or less.
    """
    if num_hidden_layers <= 5:
        raise ValueError("Not enough hidden layers for default EAGLE3 capture")

    early_layer = 1
    middle_layer = num_hidden_layers // 2 - 1
    late_layer = num_hidden_layers - 4
    return {early_layer, middle_layer, late_layer}
| 52 | + |
| 53 | + |
41 | 54 | _TRANSFORMS_SHORTCUT_LOOKUP = { |
42 | 55 | "attn_backend": ("insert_cached_attention.backend", "transformers_replace_cached_attn.backend"), |
43 | 56 | "free_mem_ratio": ("resize_kv_cache.free_mem_ratio",), |
@@ -150,6 +163,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): |
150 | 163 |
|
151 | 164 | enable_chunked_prefill: bool = Field(default=False, description="Enable chunked prefill.") |
152 | 165 |
|
| 166 | + draft_checkpoint_loader: Optional[object] = Field( |
| 167 | + default=None, |
| 168 | + description="The checkpoint loader to use for the draft model when using speculative decoding with two models.", |
| 169 | + ) |
| 170 | + |
153 | 171 | ### INFERENCE OPTIMIZER CONFIG ################################################################# |
154 | 172 | mode: Literal["graph", "transformers"] = Field( |
155 | 173 | default="graph", |
@@ -190,11 +208,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): |
190 | 208 | ), |
191 | 209 | ) |
192 | 210 |
|
193 | | - draft_checkpoint_loader: Optional[object] = Field( |
194 | | - default=None, |
195 | | - description="The checkpoint loader to use for the draft model when using speculative decoding with two models.", |
196 | | - ) |
197 | | - |
198 | 211 | ### SEQUENCE INTERFACE CONFIG ################################################################## |
199 | 212 | max_input_len: int = Field(default=1024, description="The maximum input length.") |
200 | 213 | max_num_tokens: Optional[int] = Field(default=None, description="The maximum number of tokens.") |
@@ -401,6 +414,26 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A |
401 | 414 | msg = "AutoDeploy only supports parallelization via the `world_size` argument." |
402 | 415 | return _check_for_default_value_only(cls, value, info, msg) |
403 | 416 |
|
@model_validator(mode="after")
def default_eagle3_layers_to_capture(self):
    """Fill in and propagate the EAGLE3 capture layers.

    Does nothing unless ``speculative_config`` is an ``EagleDecodingConfig``.
    If that config has ``eagle3_layers_to_capture`` set to ``None``, a default
    is computed from the ``num_hidden_layers`` attribute of the first element
    returned by the factory's ``_get_model_config()``. The layers are then
    written into ``self.transforms`` under
    ``["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"]``.

    Returns:
        This instance.
    """
    spec_config = self.speculative_config
    if spec_config is None or not isinstance(spec_config, EagleDecodingConfig):
        return self

    if spec_config.eagle3_layers_to_capture is None:
        model_config = self.create_factory()._get_model_config()[0]
        # Inside a method body this name resolves to the module-level helper of
        # the same name, not to this validator (class attributes are not in a
        # function's lookup scope).
        spec_config.eagle3_layers_to_capture = default_eagle3_layers_to_capture(
            model_config.num_hidden_layers
        )

    # insert the layers to capture into the transforms config.
    # NOTE(review): assumes the "detect_hidden_states_for_capture" entry already
    # exists in self.transforms; a missing entry raises KeyError here.
    capture_transform = self.transforms["detect_hidden_states_for_capture"]
    capture_transform["eagle3_layers_to_capture"] = spec_config.eagle3_layers_to_capture

    return self
| 436 | + |
404 | 437 | @model_validator(mode="after") |
405 | 438 | def validate_parallel_config(self): |
406 | 439 | """Setup parallel config according to world_size. |
|
0 commit comments