@@ -2736,14 +2736,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
2736
2736
2737
2737
# Try to get auxiliary layers from speculative config,
2738
2738
# otherwise use model's default layers
2739
- aux_layers = (self ._get_eagle3_aux_layers_from_config () or
2740
- self .model .get_eagle3_aux_hidden_state_layers ())
2741
-
2742
- if aux_layers != self .model .get_eagle3_aux_hidden_state_layers (
2743
- ):
2739
+ aux_layers = self ._get_eagle3_aux_layers_from_config ()
2740
+ if aux_layers :
2744
2741
logger .info (
2745
2742
"Using auxiliary layers from speculative config: %s" ,
2746
2743
aux_layers )
2744
+ else :
2745
+ aux_layers = self .model .get_eagle3_aux_hidden_state_layers (
2746
+ )
2747
2747
2748
2748
self .model .set_aux_hidden_state_layers (aux_layers )
2749
2749
time_after_load = time .perf_counter ()
@@ -2797,7 +2797,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
2797
2797
CUDAGraphMode .NONE , self .device )
2798
2798
2799
2799
def _get_eagle3_aux_layers_from_config (self ) -> Optional [tuple [int , ...]]:
2800
- """Extract Eagle3 auxiliary layer IDs from speculative config.
2800
+ """Extract Eagle3 auxiliary layer indices from speculative config.
2801
+
2802
+ These indices specify which hidden states from the base model should
2803
+ be used as auxiliary inputs for the Eagle3 drafter model during
2804
+ speculative decoding.
2801
2805
2802
2806
Returns:
2803
2807
Tuple of layer indices if found in draft model config,
@@ -2807,18 +2811,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
2807
2811
and self .speculative_config .draft_model_config ):
2808
2812
return None
2809
2813
2810
- try :
2811
- hf_config = self .speculative_config .draft_model_config .hf_config
2812
- if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2813
- return None
2814
-
2815
- layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2816
- if layer_ids and isinstance (layer_ids , (list , tuple )):
2817
- return tuple (layer_ids )
2818
- except Exception as e :
2819
- logger .warning (
2820
- "Failed to read auxiliary layers from speculative config: %s" ,
2821
- e )
2814
+ hf_config = self .speculative_config .draft_model_config .hf_config
2815
+ if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2816
+ return None
2817
+
2818
+ layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2819
+ if layer_ids and isinstance (layer_ids , (list , tuple )):
2820
+ return tuple (layer_ids )
2822
2821
2823
2822
return None
2824
2823
0 commit comments