@@ -781,16 +781,6 @@ def _get_eagle_module_inputs(
781781 eagle_inputs = {}
782782
783783 eagle_inputs ["position_ids" ] = position_ids
784- eagle_inputs ["rotary_pos_emb" ] = rotary_pos_emb
785-
786- if self .config .sequence_parallel :
787- gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
788- gathered_features = (
789- None if features is None else gather_from_sequence_parallel_region (features )
790- )
791- else :
792- gathered_hidden_states = hidden_states
793- gathered_features = features
794784
795785 eagle_inputs ["input_ids" ] = (
796786 padded_input_ids
@@ -803,6 +793,14 @@ def _get_eagle_module_inputs(
803793 )
804794 )
805795
796+ if self .config .sequence_parallel :
797+ gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
798+ gathered_features = (
799+ None if features is None else gather_from_sequence_parallel_region (features )
800+ )
801+ else :
802+ gathered_hidden_states = hidden_states
803+ gathered_features = features
806804 if gathered_features is not None :
807805 feature = gathered_features [- s :]
808806 eagle_inputs ["hidden_states" ] = (
@@ -829,6 +827,12 @@ def _get_eagle_module_inputs(
829827 attn_mask , (ttt_step - 1 ) * self .eagle_config .parallel_draft_step + parallel_draft_step
830828 )
831829
830+ eagle_inputs ["rotary_pos_emb" ] = torch .cat (
831+ [rotary_pos_emb ]
832+ * ((ttt_step - 1 ) * self .eagle_config .parallel_draft_step + parallel_draft_step ),
833+ dim = 0 ,
834+ )
835+
832836 eagle_inputs ["embedding" ] = self .embedding (
833837 input_ids = eagle_inputs ["input_ids" ],
834838 position_ids = eagle_inputs ["position_ids" ],
0 commit comments