@@ -914,55 +914,6 @@ def _get_eagle_module_inputs(
             eagle_inputs["attention_mask"] = attn_mask
             eagle_inputs["position_ids"] = position_ids
             eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
-
-            if self.config.sequence_parallel:
-                gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
-            else:
-                gathered_hidden_states = hidden_states
-            eagle_inputs["hidden_states"] = gathered_hidden_states
-
-            for i in range(self.eagle_config.parallel_draft_step - 1):
-                eagle_inputs["input_ids"] = torch.cat(
-                    (
-                        eagle_inputs["input_ids"],
-                        torch.full(
-                            padded_input_ids.shape,
-                            getattr(self, f"mask_token_{i}"),
-                            device=padded_input_ids.device,
-                            dtype=padded_input_ids.dtype,
-                        ),
-                    ),
-                    dim=-1,
-                )
-
-                eagle_inputs["hidden_states"] = torch.cat(
-                    (
-                        eagle_inputs["hidden_states"],
-                        torch.zeros(
-                            (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device
-                        ),
-                        gathered_hidden_states[: -(1 + i)],
-                    ),
-                    dim=0,
-                )
-
-                eagle_inputs["position_ids"] = torch.cat(
-                    (eagle_inputs["position_ids"], position_ids), dim=-1
-                )
-
-                if rotary_pos_emb is not None:
-                    eagle_inputs["rotary_pos_emb"] = torch.cat(
-                        (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
-                    )
-
-            if self.config.sequence_parallel:
-                eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
-                    eagle_inputs["hidden_states"]
-                )
-
-            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-                attn_mask, self.eagle_config.parallel_draft_step
-            )
         elif features.shape[0] == hidden_states.shape[0]:
             eagle_inputs["input_ids"] = torch.cat(
                 (padded_input_ids, padded_input_ids),