
Commit 5e02178

Review comments
Signed-off-by: Rahul Tuli <[email protected]>
1 parent 0f533ee commit 5e02178

File tree

2 files changed: +22 -21 lines changed


vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 5 additions & 3 deletions
@@ -17,11 +17,15 @@ def decorator(fn):
 def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
     """
     Apply Eagle-3 specific configuration transformations.
-
+
     Eagle-3 specific fields:
     - draft_vocab_size: Size of the draft model's vocabulary
     - target_hidden_size: Hidden size of the target model
     - norm_before_residual: Whether to apply norm before residual connection
+    - eagle_aux_hidden_state_layer_ids: List of layer indices from the base
+      model to use as auxiliary inputs for the Eagle3 drafter. These layers
+      provide intermediate hidden states that help the drafter make better
+      predictions. This is the standard field used in Eagle3 checkpoints.
     """
 
     vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
@@ -33,5 +37,3 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
     if config_dict.get("eagle_aux_hidden_state_layer_ids"):
         vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
             "eagle_aux_hidden_state_layer_ids"]
-    if config_dict.get("inference_type"):
-        vllm_config["inference_type"] = config_dict["inference_type"]
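To make the configuration mapping concrete, here is a minimal sketch (not part of the commit) of how update_eagle3 is expected to translate a speculators-style Eagle3 config dict after this change, assuming the function only reads the fields listed in its docstring and is importable directly from the module shown above; all values below are hypothetical.

from vllm.transformers_utils.configs.speculators.algos import update_eagle3

# Hypothetical speculators-style Eagle3 checkpoint config (example values only).
config_dict = {
    "draft_vocab_size": 32000,
    "target_hidden_size": 4096,
    "norm_before_residual": True,
    "eagle_aux_hidden_state_layer_ids": [1, 23, 44],  # base-model layer indices
    "inference_type": "text",  # present in some configs; no longer forwarded
}

vllm_config: dict = {}
update_eagle3(config_dict, vllm_config)

# The auxiliary-layer ids are copied through verbatim...
print(vllm_config["eagle_aux_hidden_state_layer_ids"])  # [1, 23, 44]
# ...while "inference_type" is no longer propagated after this commit.
print("inference_type" in vllm_config)  # False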

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 18 deletions
@@ -2736,14 +2736,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
 
             # Try to get auxiliary layers from speculative config,
             # otherwise use model's default layers
-            aux_layers = (self._get_eagle3_aux_layers_from_config() or
-                          self.model.get_eagle3_aux_hidden_state_layers())
-
-            if aux_layers != self.model.get_eagle3_aux_hidden_state_layers(
-            ):
+            aux_layers = self._get_eagle3_aux_layers_from_config()
+            if aux_layers:
                 logger.info(
                     "Using auxiliary layers from speculative config: %s",
                     aux_layers)
+            else:
+                aux_layers = self.model.get_eagle3_aux_hidden_state_layers(
+                )
 
             self.model.set_aux_hidden_state_layers(aux_layers)
         time_after_load = time.perf_counter()
@@ -2797,7 +2797,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                 CUDAGraphMode.NONE, self.device)
 
     def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
-        """Extract Eagle3 auxiliary layer IDs from speculative config.
+        """Extract Eagle3 auxiliary layer indices from speculative config.
+
+        These indices specify which hidden states from the base model should
+        be used as auxiliary inputs for the Eagle3 drafter model during
+        speculative decoding.
 
         Returns:
             Tuple of layer indices if found in draft model config,
@@ -2807,18 +2811,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
                 and self.speculative_config.draft_model_config):
             return None
 
-        try:
-            hf_config = self.speculative_config.draft_model_config.hf_config
-            if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
-                return None
-
-            layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
-            if layer_ids and isinstance(layer_ids, (list, tuple)):
-                return tuple(layer_ids)
-        except Exception as e:
-            logger.warning(
-                "Failed to read auxiliary layers from speculative config: %s",
-                e)
+        hf_config = self.speculative_config.draft_model_config.hf_config
+        if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
+            return None
+
+        layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
+        if layer_ids and isinstance(layer_ids, (list, tuple)):
+            return tuple(layer_ids)
 
         return None
 
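As a self-contained illustration (not from the commit) of the simplified extraction path, the snippet below mirrors what _get_eagle3_aux_layers_from_config now does, using SimpleNamespace as a stand-in for the draft model's hf_config; the layer indices are hypothetical.

from types import SimpleNamespace
from typing import Optional


def get_aux_layers(hf_config) -> Optional[tuple[int, ...]]:
    # Mirrors the simplified logic: a missing attribute or an empty/invalid
    # value yields None, otherwise the ids are returned as a tuple.
    if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
        return None
    layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
    if layer_ids and isinstance(layer_ids, (list, tuple)):
        return tuple(layer_ids)
    return None


# Draft config that carries the override (hypothetical indices).
print(get_aux_layers(SimpleNamespace(eagle_aux_hidden_state_layer_ids=[1, 23, 44])))  # (1, 23, 44)
# Draft config without the field.
print(get_aux_layers(SimpleNamespace()))  # None

When the field is absent or empty the helper returns None, and load_model then falls back to self.model.get_eagle3_aux_hidden_state_layers().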
