@@ -2753,14 +2753,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
2753
2753
2754
2754
# Try to get auxiliary layers from speculative config,
2755
2755
# otherwise use model's default layers
2756
- aux_layers = (self ._get_eagle3_aux_layers_from_config () or
2757
- self .model .get_eagle3_aux_hidden_state_layers ())
2758
-
2759
- if aux_layers != self .model .get_eagle3_aux_hidden_state_layers (
2760
- ):
2756
+ aux_layers = self ._get_eagle3_aux_layers_from_config ()
2757
+ if aux_layers :
2761
2758
logger .info (
2762
2759
"Using auxiliary layers from speculative config: %s" ,
2763
2760
aux_layers )
2761
+ else :
2762
+ aux_layers = self .model .get_eagle3_aux_hidden_state_layers (
2763
+ )
2764
2764
2765
2765
self .model .set_aux_hidden_state_layers (aux_layers )
2766
2766
time_after_load = time .perf_counter ()
@@ -2814,7 +2814,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
2814
2814
CUDAGraphMode .NONE , self .device )
2815
2815
2816
2816
def _get_eagle3_aux_layers_from_config (self ) -> Optional [tuple [int , ...]]:
2817
- """Extract Eagle3 auxiliary layer IDs from speculative config.
2817
+ """Extract Eagle3 auxiliary layer indices from speculative config.
2818
+
2819
+ These indices specify which hidden states from the base model should
2820
+ be used as auxiliary inputs for the Eagle3 drafter model during
2821
+ speculative decoding.
2818
2822
2819
2823
Returns:
2820
2824
Tuple of layer indices if found in draft model config,
@@ -2824,18 +2828,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
2824
2828
and self .speculative_config .draft_model_config ):
2825
2829
return None
2826
2830
2827
- try :
2828
- hf_config = self .speculative_config .draft_model_config .hf_config
2829
- if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2830
- return None
2831
-
2832
- layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2833
- if layer_ids and isinstance (layer_ids , (list , tuple )):
2834
- return tuple (layer_ids )
2835
- except Exception as e :
2836
- logger .warning (
2837
- "Failed to read auxiliary layers from speculative config: %s" ,
2838
- e )
2831
+ hf_config = self .speculative_config .draft_model_config .hf_config
2832
+ if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
2833
+ return None
2834
+
2835
+ layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
2836
+ if layer_ids and isinstance (layer_ids , (list , tuple )):
2837
+ return tuple (layer_ids )
2839
2838
2840
2839
return None
2841
2840
0 commit comments