Skip to content

Commit 2158396

Browse files
committed
Review comments
Signed-off-by: Rahul Tuli <[email protected]>
1 parent 1903535 commit 2158396

File tree

2 files changed

+22 additions, -21 deletions (lines changed)

2 files changed

+22 additions, -21 deletions (lines changed)

vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,15 @@ def decorator(fn):
1717
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
1818
"""
1919
Apply Eagle-3 specific configuration transformations.
20-
20+
2121
Eagle-3 specific fields:
2222
- draft_vocab_size: Size of the draft model's vocabulary
2323
- target_hidden_size: Hidden size of the target model
2424
- norm_before_residual: Whether to apply norm before residual connection
25+
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base
26+
model to use as auxiliary inputs for the Eagle3 drafter. These layers
27+
provide intermediate hidden states that help the drafter make better
28+
predictions. This is the standard field used in Eagle3 checkpoints.
2529
"""
2630

2731
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
@@ -33,5 +37,3 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
3337
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
3438
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
3539
"eagle_aux_hidden_state_layer_ids"]
36-
if config_dict.get("inference_type"):
37-
vllm_config["inference_type"] = config_dict["inference_type"]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -2753,14 +2753,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
27532753

27542754
# Try to get auxiliary layers from speculative config,
27552755
# otherwise use model's default layers
2756-
aux_layers = (self._get_eagle3_aux_layers_from_config() or
2757-
self.model.get_eagle3_aux_hidden_state_layers())
2758-
2759-
if aux_layers != self.model.get_eagle3_aux_hidden_state_layers(
2760-
):
2756+
aux_layers = self._get_eagle3_aux_layers_from_config()
2757+
if aux_layers:
27612758
logger.info(
27622759
"Using auxiliary layers from speculative config: %s",
27632760
aux_layers)
2761+
else:
2762+
aux_layers = self.model.get_eagle3_aux_hidden_state_layers(
2763+
)
27642764

27652765
self.model.set_aux_hidden_state_layers(aux_layers)
27662766
time_after_load = time.perf_counter()
@@ -2814,7 +2814,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
28142814
CUDAGraphMode.NONE, self.device)
28152815

28162816
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
2817-
"""Extract Eagle3 auxiliary layer IDs from speculative config.
2817+
"""Extract Eagle3 auxiliary layer indices from speculative config.
2818+
2819+
These indices specify which hidden states from the base model should
2820+
be used as auxiliary inputs for the Eagle3 drafter model during
2821+
speculative decoding.
28182822
28192823
Returns:
28202824
Tuple of layer indices if found in draft model config,
@@ -2824,18 +2828,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
28242828
and self.speculative_config.draft_model_config):
28252829
return None
28262830

2827-
try:
2828-
hf_config = self.speculative_config.draft_model_config.hf_config
2829-
if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
2830-
return None
2831-
2832-
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
2833-
if layer_ids and isinstance(layer_ids, (list, tuple)):
2834-
return tuple(layer_ids)
2835-
except Exception as e:
2836-
logger.warning(
2837-
"Failed to read auxiliary layers from speculative config: %s",
2838-
e)
2831+
hf_config = self.speculative_config.draft_model_config.hf_config
2832+
if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
2833+
return None
2834+
2835+
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
2836+
if layer_ids and isinstance(layer_ids, (list, tuple)):
2837+
return tuple(layer_ids)
28392838

28402839
return None
28412840

0 commit comments

Comments
 (0)