@@ -760,7 +760,7 @@ def _get_eagle_module_inputs(
         position_ids: torch.Tensor,
         features: torch.Tensor | None = None,
         ttt_step: int = 0,
-        parallel_draft_step: int = 0,
+        parallel_draft_index: int = 0,
     ):
         """Getting EAGLE module inputs."""
         b = hidden_states.shape[1]
@@ -784,10 +784,10 @@ def _get_eagle_module_inputs(
 
         eagle_inputs["input_ids"] = (
             padded_input_ids
-            if parallel_draft_step == 0
+            if parallel_draft_index == 0
             else torch.full(
                 padded_input_ids.shape,
-                getattr(self, f"mask_token_{parallel_draft_step - 1}"),
+                getattr(self, f"mask_token_{parallel_draft_index - 1}"),
                 device=padded_input_ids.device,
                 dtype=padded_input_ids.dtype,
             )
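The branch above feeds real token ids to draft slot 0 and fills every later slot entirely with that slot's learned mask token. A self-contained sketch of the selection, assuming the mask-token attributes are plain integer ids (in the real module they are registered on `self` as `mask_token_0`, `mask_token_1`, ...; the class and values here are made up):

import torch

class _MaskTokenDemo:
    def __init__(self):
        # Stand-ins for the registered mask-token ids (assumed integers).
        self.mask_token_0 = 32001
        self.mask_token_1 = 32002

    def select_input_ids(self, padded_input_ids: torch.Tensor, parallel_draft_index: int):
        # Slot 0: pass the padded input ids through unchanged.
        if parallel_draft_index == 0:
            return padded_input_ids
        # Slots > 0: a tensor of the same shape, filled with that slot's mask token.
        return torch.full(
            padded_input_ids.shape,
            getattr(self, f"mask_token_{parallel_draft_index - 1}"),
            device=padded_input_ids.device,
            dtype=padded_input_ids.dtype,
        )

demo = _MaskTokenDemo()
ids = torch.tensor([[1, 5, 9]])
print(demo.select_input_ids(ids, 0))  # original ids
print(demo.select_input_ids(ids, 2))  # every position becomes mask_token_1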
@@ -824,12 +824,12 @@ def _get_eagle_module_inputs(
         )
 
         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-            attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step
+            attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index
         )
 
         eagle_inputs["rotary_pos_emb"] = torch.cat(
             [rotary_pos_emb]
-            * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1),
+            * (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index + 1),
             dim=0,
         )
 
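The rotary embeddings are tiled to match the sequence length implied by the multi-step attention mask: one copy of the base table per step processed so far, i.e. `offset + 1` copies along the sequence dimension. A hedged sketch, assuming a `[seq, 1, 1, dim]` layout for the embedding table:

import torch

rotary_pos_emb = torch.randn(4, 1, 1, 8)  # assumed [seq, 1, 1, head_dim] layout
offset = 1 * 3 + 2                        # ttt_step * parallel_draft_step + parallel_draft_index
tiled = torch.cat([rotary_pos_emb] * (offset + 1), dim=0)
print(tiled.shape)  # torch.Size([24, 1, 1, 8]) -- six copies of the base table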
@@ -1075,7 +1075,7 @@ def forward(
                 position_ids=position_ids,
                 features=eagle_hidden_states_pre_norm,
                 ttt_step=ttt_step,
-                parallel_draft_step=i,
+                parallel_draft_index=i,
             )
 
             _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward(