Skip to content

Commit 1c1d679

Browse files
committed
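Review comments: document `eagle_aux_hidden_state_layer_ids` in `update_eagle3`, stop forwarding `inference_type`, and simplify Eagle3 auxiliary-layer selection in `gpu_model_runner.py` (config value first, model default as fallback; broad try/except removed)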
Signed-off-by: Rahul Tuli <[email protected]>
1 parent 06c6c93 commit 1c1d679

File tree

2 files changed: +20 / -24 lines changed

2 files changed: +20 / -24 lines changed

vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
2121
- draft_vocab_size: Size of the draft model's vocabulary
2222
- target_hidden_size: Hidden size of the target model
2323
- norm_before_residual: Whether to apply norm before residual connection
24+
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base
25+
model to use as auxiliary inputs for the Eagle3 drafter. These layers
26+
provide intermediate hidden states that help the drafter make better
27+
predictions. This is the standard field used in Eagle3 checkpoints.
2428
"""
2529

2630
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
@@ -31,5 +35,3 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
3135
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
3236
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
3337
"eagle_aux_hidden_state_layer_ids"]
34-
if config_dict.get("inference_type"):
35-
vllm_config["inference_type"] = config_dict["inference_type"]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2951,19 +2951,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
29512951

29522952
# Try to get auxiliary layers from speculative config,
29532953
# otherwise use model's default layers
2954-
aux_layers = (
2955-
self._get_eagle3_aux_layers_from_config()
2956-
or self.model.get_eagle3_aux_hidden_state_layers()
2957-
)
2958-
2959-
if (
2960-
aux_layers
2961-
!= self.model.get_eagle3_aux_hidden_state_layers()
2962-
):
2954+
aux_layers = self._get_eagle3_aux_layers_from_config()
2955+
if aux_layers:
29632956
logger.info(
29642957
"Using auxiliary layers from speculative config: %s",
29652958
aux_layers,
29662959
)
2960+
else:
2961+
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
29672962

29682963
self.model.set_aux_hidden_state_layers(aux_layers)
29692964
time_after_load = time.perf_counter()
@@ -3021,7 +3016,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
30213016
)
30223017

30233018
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
3024-
"""Extract Eagle3 auxiliary layer IDs from speculative config.
3019+
"""Extract Eagle3 auxiliary layer indices from speculative config.
3020+
3021+
These indices specify which hidden states from the base model should
3022+
be used as auxiliary inputs for the Eagle3 drafter model during
3023+
speculative decoding.
30253024
30263025
Returns:
30273026
Tuple of layer indices if found in draft model config,
@@ -3031,18 +3030,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
30313030
and self.speculative_config.draft_model_config):
30323031
return None
30333032

3034-
try:
3035-
hf_config = self.speculative_config.draft_model_config.hf_config
3036-
if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
3037-
return None
3038-
3039-
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
3040-
if layer_ids and isinstance(layer_ids, (list, tuple)):
3041-
return tuple(layer_ids)
3042-
except Exception as e:
3043-
logger.warning(
3044-
"Failed to read auxiliary layers from speculative config: %s",
3045-
e)
3033+
hf_config = self.speculative_config.draft_model_config.hf_config
3034+
if not hasattr(hf_config, 'eagle_aux_hidden_state_layer_ids'):
3035+
return None
3036+
3037+
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
3038+
if layer_ids and isinstance(layer_ids, (list, tuple)):
3039+
return tuple(layer_ids)
30463040

30473041
return None
30483042

0 commit comments

Comments
 (0)