@@ -780,15 +780,9 @@ def _get_eagle_module_inputs(
780
780
781
781
eagle_inputs = {}
782
782
783
- eagle_inputs ["input_ids" ] = torch .empty (
784
- 0 , dtype = padded_input_ids .dtype , device = padded_input_ids .device
785
- )
786
- eagle_inputs ["position_ids" ] = torch .empty (
787
- 0 , dtype = position_ids .dtype , device = position_ids .device
788
- )
789
- eagle_inputs ["rotary_pos_emb" ] = torch .empty (
790
- 0 , dtype = rotary_pos_emb .dtype , device = rotary_pos_emb .device
791
- )
783
+ eagle_inputs ["position_ids" ] = position_ids
784
+ eagle_inputs ["rotary_pos_emb" ] = rotary_pos_emb
785
+
792
786
if self .config .sequence_parallel :
793
787
gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
794
788
gathered_features = (
@@ -797,56 +791,34 @@ def _get_eagle_module_inputs(
797
791
else :
798
792
gathered_hidden_states = hidden_states
799
793
gathered_features = features
800
- eagle_inputs ["hidden_states" ] = torch .empty (
801
- 0 , dtype = gathered_hidden_states .dtype , device = gathered_hidden_states .device
802
- )
803
794
804
- for step in range (ttt_step ):
805
- for i in range (parallel_draft_step ):
806
- eagle_inputs ["input_ids" ] = torch .cat (
807
- (
808
- eagle_inputs ["input_ids" ],
809
- padded_input_ids
810
- if i == 0
811
- else torch .full (
812
- padded_input_ids .shape ,
813
- getattr (self , f"mask_token_{ i - 1 } " ),
814
- device = padded_input_ids .device ,
815
- dtype = padded_input_ids .dtype ,
816
- ),
817
- ),
818
- dim = - 1 ,
819
- )
795
+ eagle_inputs ["input_ids" ] = (
796
+ padded_input_ids
797
+ if parallel_draft_step == 1
798
+ else torch .full (
799
+ padded_input_ids .shape ,
800
+ getattr (self , f"mask_token_{ parallel_draft_step - 2 } " ),
801
+ device = padded_input_ids .device ,
802
+ dtype = padded_input_ids .dtype ,
803
+ )
804
+ )
820
805
821
- if step > 0 :
822
- feature = gathered_features [- s :]
823
- eagle_inputs ["hidden_states" ] = torch .cat (
824
- (
825
- eagle_inputs ["hidden_states" ],
826
- gathered_hidden_states
827
- if step == 0
828
- else torch .cat (
829
- (
830
- torch .zeros (
831
- (1 , b , h ),
832
- dtype = hidden_states .dtype ,
833
- device = hidden_states .device ,
834
- ),
835
- feature [:- 1 , :, :],
836
- )
837
- ),
806
+ if gathered_features is not None :
807
+ feature = gathered_features [- s :]
808
+ eagle_inputs ["hidden_states" ] = (
809
+ gathered_hidden_states
810
+ if ttt_step == 1
811
+ else torch .cat (
812
+ (
813
+ torch .zeros (
814
+ (1 , b , h ),
815
+ dtype = hidden_states .dtype ,
816
+ device = hidden_states .device ,
838
817
),
839
- dim = 0 ,
840
- )
841
-
842
- eagle_inputs ["position_ids" ] = torch .cat (
843
- (eagle_inputs ["position_ids" ], position_ids ), dim = - 1
818
+ feature [:- 1 , :, :],
844
819
)
845
-
846
- if rotary_pos_emb is not None :
847
- eagle_inputs ["rotary_pos_emb" ] = torch .cat (
848
- (eagle_inputs ["rotary_pos_emb" ], rotary_pos_emb ), dim = 0
849
- )
820
+ )
821
+ )
850
822
851
823
if self .config .sequence_parallel :
852
824
eagle_inputs ["hidden_states" ] = scatter_to_sequence_parallel_region (
0 commit comments