@@ -845,7 +845,7 @@ def _get_eagle_module_inputs(
         rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1])
 
         attn_mask = attention_mask.clone().detach()
-        attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:]
+        attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
         attn_mask[:, :, -1, :] = True
         attn_mask[:, :, :, -1] = True
 
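The one-character change above matters because the old statement read from and wrote to overlapping slices of the same tensor, so the shift could pick up values that had already been overwritten within the same assignment. Reading from the untouched `attention_mask` avoids that aliasing. A minimal standalone sketch of the intended construction (the `[batch, 1, seq, seq]` boolean layout here is an assumption for illustration, not taken from the file):

```python
import torch

# Standalone sketch (not the library code): build the shifted EAGLE-style mask
# the way the fixed line does, always reading from the untouched source tensor.
def shift_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    attn_mask = attention_mask.clone().detach()
    # Read from `attention_mask`, which is never written here, so the
    # overlapping destination slice cannot observe partially updated values.
    attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
    attn_mask[:, :, -1, :] = True
    attn_mask[:, :, :, -1] = True
    return attn_mask


if __name__ == "__main__":
    s = 4
    base = torch.ones(s, s).tril().bool().view(1, 1, s, s)  # toy causal mask
    print(shift_attention_mask(base).int())
```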
@@ -914,6 +914,55 @@ def _get_eagle_module_inputs(
             eagle_inputs["attention_mask"] = attn_mask
             eagle_inputs["position_ids"] = position_ids
             eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
+
+            if self.config.sequence_parallel:
+                gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            else:
+                gathered_hidden_states = hidden_states
+            eagle_inputs["hidden_states"] = gathered_hidden_states
+
+            for i in range(self.eagle_config.parallel_draft_step - 1):
+                eagle_inputs["input_ids"] = torch.cat(
+                    (
+                        eagle_inputs["input_ids"],
+                        torch.full(
+                            padded_input_ids.shape,
+                            getattr(self, f"mask_token_{i}"),
+                            device=padded_input_ids.device,
+                            dtype=padded_input_ids.dtype,
+                        ),
+                    ),
+                    dim=-1,
+                )
+
+                eagle_inputs["hidden_states"] = torch.cat(
+                    (
+                        eagle_inputs["hidden_states"],
+                        torch.zeros(
+                            (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device
+                        ),
+                        gathered_hidden_states[: -(1 + i)],
+                    ),
+                    dim=0,
+                )
+
+                eagle_inputs["position_ids"] = torch.cat(
+                    (eagle_inputs["position_ids"], position_ids), dim=-1
+                )
+
+                if rotary_pos_emb is not None:
+                    eagle_inputs["rotary_pos_emb"] = torch.cat(
+                        (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
+                    )
+
+            if self.config.sequence_parallel:
+                eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
+                    eagle_inputs["hidden_states"]
+                )
+
+            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
+                attn_mask, self.eagle_config.parallel_draft_step
+            )
 
         elif features.shape[0] == hidden_states.shape[0]:
             eagle_inputs["input_ids"] = torch.cat(
                 (padded_input_ids, padded_input_ids),
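The new parallel-draft branch is mostly shape bookkeeping: for each extra draft token it appends one more full-length copy of the inputs along the sequence dimension — `input_ids` filled with the learned `mask_token_{i}`, hidden states shifted right by `1 + i` positions and zero-padded, plus another copy of `position_ids` and `rotary_pos_emb` — and only after the loop replaces the mask with `set_multi_step_attention_mask` over the widened sequence (re-scattering across sequence-parallel ranks when enabled). A toy sketch of just the shape arithmetic, with placeholder sizes and mask-token ids that are not taken from the code:

```python
import torch

# Toy sketch of the parallel-draft concatenation (shapes only). The sizes and
# placeholder mask-token ids below are made up for illustration.
s, b, h = 6, 2, 8                 # sequence, batch, hidden (Megatron uses [s, b, h])
parallel_draft_step = 3

input_ids = torch.randint(0, 100, (b, s))
hidden_states = torch.randn(s, b, h)
mask_tokens = [1000 + i for i in range(parallel_draft_step - 1)]  # stand-ins for mask_token_{i}

draft_input_ids = input_ids
draft_hidden = hidden_states
for i in range(parallel_draft_step - 1):
    # One more full-length copy per extra draft token, filled with the i-th mask token.
    draft_input_ids = torch.cat(
        (draft_input_ids, torch.full((b, s), mask_tokens[i], dtype=input_ids.dtype)), dim=-1
    )
    # The matching hidden states are shifted right by (1 + i) steps and zero-padded,
    # so every appended copy still has length s.
    draft_hidden = torch.cat(
        (draft_hidden, torch.zeros(1 + i, b, h), hidden_states[: -(1 + i)]), dim=0
    )

assert draft_input_ids.shape == (b, s * parallel_draft_step)
assert draft_hidden.shape == (s * parallel_draft_step, b, h)
```

Note that `attn_mask` itself is never extended inside the loop; the mask for the concatenated length comes from the single `set_multi_step_attention_mask(attn_mask, parallel_draft_step)` call at the end of the branch.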