Skip to content

Commit 792031e

Browse files
fixing comments from AI
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 17029de commit 792031e

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _check_for_default_value_only(
4646

4747

4848
def default_eagle3_layers_to_capture(num_hidden_layers: int) -> Set[int]:
49-
if num_hidden_layers <= 5:
49+
if num_hidden_layers <= 6:
5050
raise ValueError("Not enough hidden layers for default EAGLE3 capture")
5151
return {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}
5252

@@ -415,7 +415,7 @@ def ensure_no_custom_parallel_config(cls, value: Any, info: ValidationInfo) -> A
415415
return _check_for_default_value_only(cls, value, info, msg)
416416

417417
@model_validator(mode="after")
418-
def default_eagle3_layers_to_capture(self):
418+
def set_eagle3_layers_to_capture(self):
419419
if self.speculative_config is None or not isinstance(
420420
self.speculative_config, EagleDecodingConfig
421421
):
@@ -428,6 +428,12 @@ def default_eagle3_layers_to_capture(self):
428428
)
429429

430430
# insert the layers to capture into the transforms config.
431+
if self.transforms is None:
432+
self.transforms = {}
433+
434+
if "detect_hidden_states_for_capture" not in self.transforms:
435+
self.transforms["detect_hidden_states_for_capture"] = {}
436+
431437
self.transforms["detect_hidden_states_for_capture"]["eagle3_layers_to_capture"] = (
432438
self.speculative_config.eagle3_layers_to_capture
433439
)

tests/integration/defs/examples/test_ad_speculative_decoding.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import os
17+
from typing import Optional
1718

1819
import pytest
1920
from build_and_run_ad import ExperimentConfig, main
@@ -65,7 +66,9 @@ def make_spec_config(spec_dec_mode: str, spec_model_path: str):
6566
raise ValueError(f"Unknown speculative mode: {spec_dec_mode}")
6667

6768

68-
def run_with_autodeploy(model, speculative_model_dir, batch_size, spec_dec_mode: str | None):
69+
def run_with_autodeploy(
70+
model, speculative_model_dir, batch_size, spec_dec_mode: Optional[str] = None
71+
):
6972
"""Run AutoDeploy with or without speculative decoding.
7073
7174
Args:

0 commit comments

Comments
 (0)