@@ -781,16 +781,6 @@ def _get_eagle_module_inputs(
781
781
eagle_inputs = {}
782
782
783
783
eagle_inputs ["position_ids" ] = position_ids
784
- eagle_inputs ["rotary_pos_emb" ] = rotary_pos_emb
785
-
786
- if self .config .sequence_parallel :
787
- gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
788
- gathered_features = (
789
- None if features is None else gather_from_sequence_parallel_region (features )
790
- )
791
- else :
792
- gathered_hidden_states = hidden_states
793
- gathered_features = features
794
784
795
785
eagle_inputs ["input_ids" ] = (
796
786
padded_input_ids
@@ -803,6 +793,14 @@ def _get_eagle_module_inputs(
803
793
)
804
794
)
805
795
796
+ if self .config .sequence_parallel :
797
+ gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
798
+ gathered_features = (
799
+ None if features is None else gather_from_sequence_parallel_region (features )
800
+ )
801
+ else :
802
+ gathered_hidden_states = hidden_states
803
+ gathered_features = features
806
804
if gathered_features is not None :
807
805
feature = gathered_features [- s :]
808
806
eagle_inputs ["hidden_states" ] = (
@@ -829,6 +827,12 @@ def _get_eagle_module_inputs(
829
827
attn_mask , (ttt_step - 1 ) * self .eagle_config .parallel_draft_step + parallel_draft_step
830
828
)
831
829
830
+ eagle_inputs ["rotary_pos_emb" ] = torch .cat (
831
+ [rotary_pos_emb ]
832
+ * ((ttt_step - 1 ) * self .eagle_config .parallel_draft_step + parallel_draft_step ),
833
+ dim = 0 ,
834
+ )
835
+
832
836
eagle_inputs ["embedding" ] = self .embedding (
833
837
input_ids = eagle_inputs ["input_ids" ],
834
838
position_ids = eagle_inputs ["position_ids" ],
0 commit comments