@@ -480,6 +480,15 @@ def __init__(
             skip_weight_param_allocation=False,
         )
 
+        # Set up learnable parallel draft embeddings and hidden_states
+        if config.parallel_draft_step > 1:
+            self.parallel_draft_embeddings = torch.nn.Parameter(
+                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
+            )
+            self.parallel_draft_hidden_states = torch.nn.Parameter(
+                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
+            )
+
     def _get_eagle_transformer_layer_spec(self, config):
         """Get the TransformerLayer implementation spec.
 
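In short, the hunk above moves creation of the learnable parallel-draft parameters into the eagle module's own __init__, keyed off config.parallel_draft_step, instead of creating them on the wrapping model (see the removal in the next hunk). Below is a minimal standalone sketch of that registration pattern; DraftStub and the sizes are illustrative stand-ins, not the real module.

import torch


class DraftStub(torch.nn.Module):
    # Hypothetical stand-in for the eagle module that now owns the
    # learnable parallel-draft parameters, mirroring the hunk above.
    def __init__(self, parallel_draft_step: int, hidden_size: int):
        super().__init__()
        if parallel_draft_step > 1:
            # One learnable row per draft position beyond the first, so
            # both tensors have shape (parallel_draft_step - 1, hidden_size).
            self.parallel_draft_embeddings = torch.nn.Parameter(
                torch.rand(parallel_draft_step - 1, hidden_size)
            )
            self.parallel_draft_hidden_states = torch.nn.Parameter(
                torch.rand(parallel_draft_step - 1, hidden_size)
            )

Registering the rows on the module itself keeps them in the eagle module's state dict, which is why the call sites later in this diff reach them through self.eagle_module.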
@@ -755,19 +764,6 @@ def modify(
         # Eagle loss functions
         self.kld = logits_kld_loss
 
-        # Set up learnable parallel draft embeddings and hidden_states
-        if self.eagle_config.parallel_draft_step > 1:
-            self.parallel_draft_embeddings = torch.nn.Parameter(
-                torch.rand(
-                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
-                )
-            )
-            self.parallel_draft_hidden_states = torch.nn.Parameter(
-                torch.rand(
-                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
-                )
-            )
-
     def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
         """When _aux_hidden_states is not empty for online, then this is EAGLE-3.
 
@@ -824,10 +820,10 @@ def _get_eagle_module_inputs(
             )
             eagle_inputs["hidden_states"] = hidden_states
         else:
-            eagle_inputs["embedding"] = self.parallel_draft_embeddings[
+            eagle_inputs["embedding"] = self.eagle_module.parallel_draft_embeddings[
                 parallel_draft_index - 1
             ].repeat(hidden_states.shape[0], hidden_states.shape[1], 1)
-            eagle_inputs["hidden_states"] = self.parallel_draft_hidden_states[
+            eagle_inputs["hidden_states"] = self.eagle_module.parallel_draft_hidden_states[
                 parallel_draft_index - 1
             ].repeat(hidden_states.shape[0], hidden_states.shape[1], 1)
 
@@ -1417,10 +1413,10 @@ def pseudo_speculative_generate(
             # parallel_draft_embeddings and parallel_draft_hidden_states
             gathered_embedding[
                 seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
-            ] = self.parallel_draft_embeddings.unsqueeze(1)
+            ] = self.eagle_module.parallel_draft_embeddings.unsqueeze(1)
             padded_hidden_states[
                 seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
-            ] = self.parallel_draft_hidden_states.unsqueeze(1)
+            ] = self.eagle_module.parallel_draft_hidden_states.unsqueeze(1)
         if self.config.sequence_parallel:
             padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states)
             embeddings = scatter_to_sequence_parallel_region(gathered_embedding)
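For reference, here is a self-contained sketch of the consumer side shown in the last two hunks; the (seq_len, batch, hidden) layout follows Megatron convention, and the tensors and sizes are made up for illustration.

import torch

# Hypothetical stand-ins for the module's parameters, assuming
# parallel_draft_step = 3 and hidden_size = 8.
parallel_draft_embeddings = torch.nn.Parameter(torch.rand(2, 8))
parallel_draft_hidden_states = torch.nn.Parameter(torch.rand(2, 8))

hidden_states = torch.zeros(4, 2, 8)  # (seq_len, batch, hidden) placeholder

# As in _get_eagle_module_inputs: select row (parallel_draft_index - 1) and
# broadcast it across the (seq_len, batch) dims of hidden_states.
parallel_draft_index = 1
embedding = parallel_draft_embeddings[parallel_draft_index - 1].repeat(
    hidden_states.shape[0], hidden_states.shape[1], 1
)
assert embedding.shape == hidden_states.shape  # (4, 2, 8)

# As in pseudo_speculative_generate: unsqueeze(1) inserts a batch axis so the
# (parallel_draft_step - 1, hidden) rows fill the trailing sequence slots.
seq_len, parallel_draft_step = 4, 3
with torch.no_grad():
    hidden_states[seq_len - parallel_draft_step + 1 : seq_len] = (
        parallel_draft_hidden_states.unsqueeze(1)
    )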