@@ -404,17 +404,31 @@ def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None
         )
         self._aux_hidden_states.append(hidden_states)
 
-    def pop_aux_hidden_states(self):
-        """Return aux hidden states from base model, and clear the list."""
+    def pop_and_gather_aux_hiddens(self):
+        """Pop auxiliary hidden states from the base model and gather them on the draft model's device."""
         # In PTQ, the forward method is called inside a try/except to find the max batch size.
         # This can leave stale aux hidden states at the front of the list.
         # To fix it, we only return the last num_aux_h items in the list.
         num_aux_h = len(self.eagle_config.eagle_aux_hidden_state_layer_ids)
         aux_h_list = self._aux_hidden_states[-num_aux_h:]
         self._aux_hidden_states.clear()
 
+        # Gather aux hidden states on the draft model device
+        aux_h_list = [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list]
+
         return aux_h_list
 
+    def _get_eagle_device(self):
+        """Return the device on which the eagle module should be placed."""
+        if self.eagle_offline:
+            # For offline training, the base model has no layers.
+            # Read the device from the base model lm_head instead.
+            return self._base_model_lm_head.weight.device
+        else:
+            # Otherwise, put eagle on the device of the base model's last layer.
+            base_model_last_layer = self._base_model.layers[-1]
+            return next(base_model_last_layer.parameters()).device
+
     def modify(
         self,
         eagle_offline,
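The two additions above follow one pattern: when the base model is sharded across several devices, the per-layer aux hidden states can land on different GPUs, so they are moved onto the draft module's device (taken from eagle_module.fc.weight.device, or from lm_head / the last base layer when placing the module itself) before any concatenation. A minimal standalone sketch of that gather-then-concatenate step, with illustrative shapes and CPU stand-ins for the devices:

import torch

# Stand-ins: on a sharded base model these would be different CUDA devices.
devices = ["cpu", "cpu", "cpu"]
aux_h_list = [torch.randn(2, 8, 16, device=d) for d in devices]

# The draft (eagle) FC layer lives on a single device; use its weight's device
# as the gather target, mirroring pop_and_gather_aux_hiddens.
fc = torch.nn.Linear(16 * len(aux_h_list), 16)
target_device = fc.weight.device

aux_h_list = [h.to(target_device) for h in aux_h_list]
eagle_input_hidden_states = fc(torch.cat(aux_h_list, dim=-1))
print(eagle_input_hidden_states.shape)  # torch.Size([2, 8, 16])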
@@ -469,7 +483,7 @@ def modify(
 
         # find base model, lm head, and embeddings paths
         self._find_base_model_parts()
-        self.eagle_module.to(self._base_model.dtype).to(self._base_model_lm_head.weight.device)
+        self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device())
 
         # Make sure word embedding and lm head are frozen
         for param in self._base_model_embeddings.parameters():
@@ -777,52 +791,52 @@ def forward(
         # ====Run eagle forward====
         eagle_loss = None
         train_accs = []
-        if self.training:
-            # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
-            b, seq_length, h = base_model_hidden_states.shape
-            if self.eagle_config.use_aux_hidden_state:
-                if "base_model_outputs" in kwargs:
-                    aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
-                else:
-                    aux_hidden_states = torch.cat(self.pop_aux_hidden_states(), dim=-1)
-                eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states)
+        # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
+        b, seq_length, h = base_model_hidden_states.shape
+        if self.eagle_config.use_aux_hidden_state:
+            if "base_model_outputs" in kwargs:
+                aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
             else:
-                eagle_input_hidden_states = base_model_hidden_states
+                aux_hidden_states = torch.cat(self.pop_and_gather_aux_hiddens(), dim=-1)
+            eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states)
+        else:
+            eagle_input_hidden_states = base_model_hidden_states
 
-            # Get eagle inputs for the first eagle forward pass
-            eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs(
-                input_ids,
-                eagle_input_hidden_states,
-                attention_mask,
-                position_ids,
-                eagle_cache,
-            )
-            with torch.no_grad():
-                inputs_embeds = self._base_model_embeddings(eagle_input_ids)
-                position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
+        # Get eagle inputs for the first eagle forward pass
+        eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs(
+            input_ids,
+            eagle_input_hidden_states,
+            attention_mask,
+            position_ids,
+            eagle_cache,
+        )
+        with torch.no_grad():
+            inputs_embeds = self._base_model_embeddings(eagle_input_ids)
+            position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
 
-            # Then, we run eagle forward
-            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
-                eagle_input_hidden_states,
-                inputs_embeds,
-                attention_mask_0,
-                position_ids,
-                position_embeddings,
-                eagle_cache,
-            )
+        # Then, we run eagle forward
+        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+            eagle_input_hidden_states,
+            inputs_embeds,
+            attention_mask_0,
+            position_ids,
+            position_embeddings,
+            eagle_cache,
+        )
 
-            past_key_values.eagle_cache = eagle_cache
+        past_key_values.eagle_cache = eagle_cache
 
-            # Compute loss on the eagle modules
-            classification_loss, acc = self._eagle_loss(
-                base_model_logits[:, 1:],
-                eagle_logits[:, :-1],
-                loss_mask[:, 1:],
-            )
-            eagle_loss = classification_loss
-            train_accs.append(acc)
+        # Compute loss on the eagle modules
+        classification_loss, acc = self._eagle_loss(
+            base_model_logits[:, 1:],
+            eagle_logits[:, :-1],
+            loss_mask[:, 1:],
+        )
+        eagle_loss = classification_loss
+        train_accs.append(acc)
 
-            # ====Perform training-time-testing with 3 extra eagle forward passes====
+        # ====Perform training-time-testing with 3 extra eagle forward passes====
+        if self.training:
             for ttt_step in range(self.num_ttt_steps):
                 eagle_input_hidden_states = torch.cat(
                     (
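For the _eagle_loss call above, the slicing is the only alignment logic visible in this diff: the draft logits at position i are scored against the base model logits at position i + 1, with the loss mask shifted the same way. A small sketch of just that alignment, using a plain soft-cross-entropy stand-in since _eagle_loss itself is not shown here:

import torch
import torch.nn.functional as F

b, s, v = 1, 6, 10  # illustrative batch, sequence, and vocab sizes
base_model_logits = torch.randn(b, s, v)
eagle_logits = torch.randn(b, s, v)
loss_mask = torch.ones(b, s)

# Drop the first base position and the last eagle position so that
# eagle_logits[:, i] lines up with base_model_logits[:, i + 1].
target = F.softmax(base_model_logits[:, 1:], dim=-1)
log_q = F.log_softmax(eagle_logits[:, :-1], dim=-1)
loss = -(loss_mask[:, 1:, None] * target * log_q).sum(-1).mean()
print(loss.item())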
@@ -931,7 +945,7 @@ def pseudo_speculative_generate(
         # Early return
         if steps < 1:
             if hasattr(self, "_aux_hidden_states"):
-                _ = self.pop_aux_hidden_states()
+                _ = self.pop_and_gather_aux_hiddens()
             return base_token, None
 
         eagle_ids = torch.cat((input_ids[:, 1:], base_token), dim=-1)
@@ -940,10 +954,7 @@ def pseudo_speculative_generate(
             # EAGLE-3
             # Only the first iteration input_hidden_states are from aux_hidden_state layers
             # Gather _aux_hidden_states from all devices before concatenation
-            gathered_aux_hidden_states = self.pop_aux_hidden_states()
-            gathered_aux_hidden_states = [
-                h.to(input_ids.device) for h in gathered_aux_hidden_states
-            ]
+            gathered_aux_hidden_states = self.pop_and_gather_aux_hiddens()
             eagle_input_hidden_states = self.eagle_module.fc(
                 torch.cat(gathered_aux_hidden_states, dim=-1)
             )