Skip to content

Commit 602fa0e

Browse files
fixing comments from AI
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 9b6838c commit 602fa0e

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _check_for_default_value_only(
4646

4747

4848
def default_eagle3_layers_to_capture(num_hidden_layers: int) -> Set[int]:
49-
if num_hidden_layers <= 5:
49+
if num_hidden_layers <= 6:
5050
raise ValueError("Not enough hidden layers for default EAGLE3 capture")
5151
return {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}
5252

@@ -434,7 +434,7 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A
434434
return _check_for_default_value_only(cls, value, info, msg)
435435

436436
@model_validator(mode="after")
437-
def default_eagle3_layers_to_capture(self):
437+
def set_eagle3_layers_to_capture(self):
438438
if self.speculative_config is None or not isinstance(
439439
self.speculative_config, EagleDecodingConfig
440440
):
@@ -447,6 +447,12 @@ def default_eagle3_layers_to_capture(self):
447447
)
448448

449449
# insert the layers to capture into the transforms config.
450+
if self.transforms is None:
451+
self.transforms = {}
452+
453+
if "detect_hidden_states_for_capture" not in self.transforms:
454+
self.transforms["detect_hidden_states_for_capture"] = {}
455+
450456
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
451457
self.speculative_config.eagle3_layers_to_capture
452458
)

tests/integration/defs/examples/test_ad_speculative_decoding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import os
17+
from typing import Optional
1718

1819
import pytest
1920
from build_and_run_ad import ExperimentConfig, main
@@ -65,7 +66,9 @@ def make_spec_config(spec_dec_mode: str, spec_model_path: str):
6566
raise ValueError(f"Unknown speculative mode: {spec_dec_mode}")
6667

6768

68-
def run_with_autodeploy(model, speculative_model_dir, batch_size, spec_dec_mode: str | None):
69+
def run_with_autodeploy(
70+
model, speculative_model_dir, batch_size, spec_dec_mode: Optional[str] = None
71+
):
6972
"""Run AutoDeploy with or without speculative decoding.
7073
7174
Args:

0 commit comments

Comments
 (0)