@@ -674,15 +674,6 @@ def modify(
674674 "Only logit distillation is supported when draft_vocab_size != vocab_size!"
675675 )
676676
677- # Set up learnable parallel draft embeddings and hidden_states
678- if self .eagle_config .parallel_draft_step > 1 :
679- self .parallel_draft_embeddings = torch .nn .Parameter (
680- torch .rand (self .eagle_config .parallel_draft_step - 1 , self .eagle_config .hidden_size )
681- )
682- self .parallel_draft_hidden_states = torch .nn .Parameter (
683- torch .rand (self .eagle_config .parallel_draft_step - 1 , self .eagle_config .hidden_size )
684- )
685-
686677 # Use default aux_hidden_state layers if use_aux_hidden_state is True
687678 # but no layer id is given
688679 # layer ids are not used in offline eagle, but we need to set this to have correct fc_input_size_multiplier
@@ -764,6 +755,19 @@ def modify(
         # Eagle loss functions
         self.kld = logits_kld_loss
 
+        # Set up learnable parallel draft embeddings and hidden_states
+        if self.eagle_config.parallel_draft_step > 1:
+            self.parallel_draft_embeddings = torch.nn.Parameter(
+                torch.rand(
+                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
+                )
+            )
+            self.parallel_draft_hidden_states = torch.nn.Parameter(
+                torch.rand(
+                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
+                )
+            )
+
     def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
         """When _aux_hidden_states is not empty for online, then this is EAGLE-3.
 
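These two hunks relocate the parallel-draft parameter setup from the middle of `modify` to its end, after the loss function is assigned; the logic is unchanged apart from line wrapping. A minimal runnable sketch of what the block creates, with assumed example values for `parallel_draft_step` and `hidden_size` (only the `torch.nn.Parameter` / `torch.rand` calls come from the diff, the rest is illustrative):

```python
import torch

parallel_draft_step = 4  # assumed example value
hidden_size = 2048       # assumed example value

if parallel_draft_step > 1:
    # One learnable row per extra parallel draft position: (k - 1, hidden_size)
    parallel_draft_embeddings = torch.nn.Parameter(
        torch.rand(parallel_draft_step - 1, hidden_size)
    )
    parallel_draft_hidden_states = torch.nn.Parameter(
        torch.rand(parallel_draft_step - 1, hidden_size)
    )
    print(parallel_draft_embeddings.shape)     # torch.Size([3, 2048])
    print(parallel_draft_hidden_states.shape)  # torch.Size([3, 2048])
```

Each extra parallel draft position gets one learnable embedding row and one learnable hidden-state row, so both tensors have `parallel_draft_step - 1` rows.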
@@ -1389,9 +1393,7 @@ def pseudo_speculative_generate(
         for _ in range(self.eagle_config.parallel_draft_step - 1):
             # Pad dummy eagle_ids and hidden_states for parallel draft
             # They will be replaced by parallel draft embeddings and hidden_states after padding
-            eagle_ids = torch.cat(
-                (eagle_ids, torch.zeros(1, 1).to(eagle_ids.dtype).to(eagle_ids.device)), dim=-1
-            )
+            eagle_ids = torch.cat((eagle_ids, eagle_ids[:, -1:]), dim=-1)
             hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
         padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
             eagle_ids, hidden_states
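This hunk changes the dummy pad token from a zero to a repeat of the last real token id. Since the comment in the loop says the padded positions are replaced by the parallel draft embeddings and hidden states anyway, slicing `eagle_ids[:, -1:]` gives a pad that inherits dtype and device for free instead of allocating a fresh `torch.zeros(1, 1)` tensor and casting it twice. A minimal runnable sketch of the loop in isolation, with assumed toy shapes (the tensor names mirror the diff; the values and the `(seq_len, hidden_size)` layout are illustrative):

```python
import torch

eagle_ids = torch.tensor([[11, 42, 7]])  # (batch=1, seq_len=3), integer token ids
hidden_states = torch.rand(3, 8)         # assumed (seq_len, hidden_size) layout
parallel_draft_step = 3                  # assumed example value

for _ in range(parallel_draft_step - 1):
    # Old: pad with a freshly allocated zero, cast to the right dtype/device:
    #   torch.cat((eagle_ids, torch.zeros(1, 1).to(eagle_ids.dtype).to(eagle_ids.device)), dim=-1)
    # New: repeat the last token id; dtype and device are inherited automatically
    eagle_ids = torch.cat((eagle_ids, eagle_ids[:, -1:]), dim=-1)
    hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)

print(eagle_ids)            # tensor([[11, 42,  7,  7,  7]])
print(hidden_states.shape)  # torch.Size([5, 8])
```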