@@ -760,7 +760,7 @@ def _get_eagle_module_inputs(
760
760
position_ids : torch .Tensor ,
761
761
features : torch .Tensor | None = None ,
762
762
ttt_step : int = 0 ,
763
- parallel_draft_step : int = 0 ,
763
+ parallel_draft_index : int = 0 ,
764
764
):
765
765
"""Getting EAGLE module inputs."""
766
766
b = hidden_states .shape [1 ]
@@ -784,10 +784,10 @@ def _get_eagle_module_inputs(
784
784
785
785
eagle_inputs ["input_ids" ] = (
786
786
padded_input_ids
787
- if parallel_draft_step == 0
787
+ if parallel_draft_index == 0
788
788
else torch .full (
789
789
padded_input_ids .shape ,
790
- getattr (self , f"mask_token_{ parallel_draft_step - 1 } " ),
790
+ getattr (self , f"mask_token_{ parallel_draft_index - 1 } " ),
791
791
device = padded_input_ids .device ,
792
792
dtype = padded_input_ids .dtype ,
793
793
)
@@ -824,12 +824,12 @@ def _get_eagle_module_inputs(
824
824
)
825
825
826
826
eagle_inputs ["attention_mask" ] = set_multi_step_attention_mask (
827
- attn_mask , ttt_step * self .eagle_config .parallel_draft_step + parallel_draft_step
827
+ attn_mask , ttt_step * self .eagle_config .parallel_draft_step + parallel_draft_index
828
828
)
829
829
830
830
eagle_inputs ["rotary_pos_emb" ] = torch .cat (
831
831
[rotary_pos_emb ]
832
- * (ttt_step * self .eagle_config .parallel_draft_step + parallel_draft_step + 1 ),
832
+ * (ttt_step * self .eagle_config .parallel_draft_step + parallel_draft_index + 1 ),
833
833
dim = 0 ,
834
834
)
835
835
@@ -1075,7 +1075,7 @@ def forward(
1075
1075
position_ids = position_ids ,
1076
1076
features = eagle_hidden_states_pre_norm ,
1077
1077
ttt_step = ttt_step ,
1078
- parallel_draft_step = i ,
1078
+ parallel_draft_index = i ,
1079
1079
)
1080
1080
1081
1081
_ , eagle_logits_ , eagle_hidden_states_pre_norm_ = self ._eagle_forward (
0 commit comments