@@ -2951,19 +2951,14 @@ def load_model(self, eep_scale_up: bool = False) -> None:
2951
2951
2952
2952
# Try to get auxiliary layers from speculative config,
2953
2953
# otherwise use model's default layers
2954
- aux_layers = (
2955
- self ._get_eagle3_aux_layers_from_config ()
2956
- or self .model .get_eagle3_aux_hidden_state_layers ()
2957
- )
2958
-
2959
- if (
2960
- aux_layers
2961
- != self .model .get_eagle3_aux_hidden_state_layers ()
2962
- ):
2954
+ aux_layers = self ._get_eagle3_aux_layers_from_config ()
2955
+ if aux_layers :
2963
2956
logger .info (
2964
2957
"Using auxiliary layers from speculative config: %s" ,
2965
2958
aux_layers ,
2966
2959
)
2960
+ else :
2961
+ aux_layers = self .model .get_eagle3_aux_hidden_state_layers ()
2967
2962
2968
2963
self .model .set_aux_hidden_state_layers (aux_layers )
2969
2964
time_after_load = time .perf_counter ()
@@ -3021,7 +3016,11 @@ def load_model(self, eep_scale_up: bool = False) -> None:
3021
3016
)
3022
3017
3023
3018
def _get_eagle3_aux_layers_from_config (self ) -> Optional [tuple [int , ...]]:
3024
- """Extract Eagle3 auxiliary layer IDs from speculative config.
3019
+ """Extract Eagle3 auxiliary layer indices from speculative config.
3020
+
3021
+ These indices specify which hidden states from the base model should
3022
+ be used as auxiliary inputs for the Eagle3 drafter model during
3023
+ speculative decoding.
3025
3024
3026
3025
Returns:
3027
3026
Tuple of layer indices if found in draft model config,
@@ -3031,18 +3030,13 @@ def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
3031
3030
and self .speculative_config .draft_model_config ):
3032
3031
return None
3033
3032
3034
- try :
3035
- hf_config = self .speculative_config .draft_model_config .hf_config
3036
- if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
3037
- return None
3038
-
3039
- layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
3040
- if layer_ids and isinstance (layer_ids , (list , tuple )):
3041
- return tuple (layer_ids )
3042
- except Exception as e :
3043
- logger .warning (
3044
- "Failed to read auxiliary layers from speculative config: %s" ,
3045
- e )
3033
+ hf_config = self .speculative_config .draft_model_config .hf_config
3034
+ if not hasattr (hf_config , 'eagle_aux_hidden_state_layer_ids' ):
3035
+ return None
3036
+
3037
+ layer_ids = hf_config .eagle_aux_hidden_state_layer_ids
3038
+ if layer_ids and isinstance (layer_ids , (list , tuple )):
3039
+ return tuple (layer_ids )
3046
3040
3047
3041
return None
3048
3042
0 commit comments