Skip to content

Commit d6f1489

Browse files
committed
change variable name to make it clear
Signed-off-by: Ye Yu <[email protected]>
1 parent b88b27a commit d6f1489

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,7 @@ def _get_eagle_module_inputs(
760760
position_ids: torch.Tensor,
761761
features: torch.Tensor | None = None,
762762
ttt_step: int = 0,
763-
parallel_draft_step: int = 0,
763+
parallel_draft_index: int = 0,
764764
):
765765
"""Getting EAGLE module inputs."""
766766
b = hidden_states.shape[1]
@@ -784,10 +784,10 @@ def _get_eagle_module_inputs(
784784

785785
eagle_inputs["input_ids"] = (
786786
padded_input_ids
787-
if parallel_draft_step == 0
787+
if parallel_draft_index == 0
788788
else torch.full(
789789
padded_input_ids.shape,
790-
getattr(self, f"mask_token_{parallel_draft_step - 1}"),
790+
getattr(self, f"mask_token_{parallel_draft_index - 1}"),
791791
device=padded_input_ids.device,
792792
dtype=padded_input_ids.dtype,
793793
)
@@ -824,12 +824,12 @@ def _get_eagle_module_inputs(
824824
)
825825

826826
eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
827-
attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step
827+
attn_mask, ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index
828828
)
829829

830830
eagle_inputs["rotary_pos_emb"] = torch.cat(
831831
[rotary_pos_emb]
832-
* (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_step + 1),
832+
* (ttt_step * self.eagle_config.parallel_draft_step + parallel_draft_index + 1),
833833
dim=0,
834834
)
835835

@@ -1075,7 +1075,7 @@ def forward(
10751075
position_ids=position_ids,
10761076
features=eagle_hidden_states_pre_norm,
10771077
ttt_step=ttt_step,
1078-
parallel_draft_step=i,
1078+
parallel_draft_index=i,
10791079
)
10801080

10811081
_, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward(

0 commit comments

Comments
 (0)