@@ -780,15 +780,9 @@ def _get_eagle_module_inputs(
780780
781781 eagle_inputs = {}
782782
783- eagle_inputs ["input_ids" ] = torch .empty (
784- 0 , dtype = padded_input_ids .dtype , device = padded_input_ids .device
785- )
786- eagle_inputs ["position_ids" ] = torch .empty (
787- 0 , dtype = position_ids .dtype , device = position_ids .device
788- )
789- eagle_inputs ["rotary_pos_emb" ] = torch .empty (
790- 0 , dtype = rotary_pos_emb .dtype , device = rotary_pos_emb .device
791- )
783+ eagle_inputs ["position_ids" ] = position_ids
784+ eagle_inputs ["rotary_pos_emb" ] = rotary_pos_emb
785+
792786 if self .config .sequence_parallel :
793787 gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
794788 gathered_features = (
@@ -797,56 +791,34 @@ def _get_eagle_module_inputs(
797791 else :
798792 gathered_hidden_states = hidden_states
799793 gathered_features = features
800- eagle_inputs ["hidden_states" ] = torch .empty (
801- 0 , dtype = gathered_hidden_states .dtype , device = gathered_hidden_states .device
802- )
803794
804- for step in range (ttt_step ):
805- for i in range (parallel_draft_step ):
806- eagle_inputs ["input_ids" ] = torch .cat (
807- (
808- eagle_inputs ["input_ids" ],
809- padded_input_ids
810- if i == 0
811- else torch .full (
812- padded_input_ids .shape ,
813- getattr (self , f"mask_token_{ i - 1 } " ),
814- device = padded_input_ids .device ,
815- dtype = padded_input_ids .dtype ,
816- ),
817- ),
818- dim = - 1 ,
819- )
795+ eagle_inputs ["input_ids" ] = (
796+ padded_input_ids
797+ if parallel_draft_step == 1
798+ else torch .full (
799+ padded_input_ids .shape ,
800+ getattr (self , f"mask_token_{ parallel_draft_step - 2 } " ),
801+ device = padded_input_ids .device ,
802+ dtype = padded_input_ids .dtype ,
803+ )
804+ )
820805
821- if step > 0 :
822- feature = gathered_features [- s :]
823- eagle_inputs ["hidden_states" ] = torch .cat (
824- (
825- eagle_inputs ["hidden_states" ],
826- gathered_hidden_states
827- if step == 0
828- else torch .cat (
829- (
830- torch .zeros (
831- (1 , b , h ),
832- dtype = hidden_states .dtype ,
833- device = hidden_states .device ,
834- ),
835- feature [:- 1 , :, :],
836- )
837- ),
806+ if gathered_features is not None :
807+ feature = gathered_features [- s :]
808+ eagle_inputs ["hidden_states" ] = (
809+ gathered_hidden_states
810+ if ttt_step == 1
811+ else torch .cat (
812+ (
813+ torch .zeros (
814+ (1 , b , h ),
815+ dtype = hidden_states .dtype ,
816+ device = hidden_states .device ,
838817 ),
839- dim = 0 ,
840- )
841-
842- eagle_inputs ["position_ids" ] = torch .cat (
843- (eagle_inputs ["position_ids" ], position_ids ), dim = - 1
818+ feature [:- 1 , :, :],
844819 )
845-
846- if rotary_pos_emb is not None :
847- eagle_inputs ["rotary_pos_emb" ] = torch .cat (
848- (eagle_inputs ["rotary_pos_emb" ], rotary_pos_emb ), dim = 0
849- )
820+ )
821+ )
850822
851823 if self .config .sequence_parallel :
852824 eagle_inputs ["hidden_states" ] = scatter_to_sequence_parallel_region (
0 commit comments