Commit b8cc0de

revert
Signed-off-by: Ye Yu <[email protected]>
1 parent e12c7c8 commit b8cc0de

1 file changed: 0 additions, 49 deletions

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 0 additions & 49 deletions
@@ -914,55 +914,6 @@ def _get_eagle_module_inputs(
             eagle_inputs["attention_mask"] = attn_mask
             eagle_inputs["position_ids"] = position_ids
             eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
-
-            if self.config.sequence_parallel:
-                gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
-            else:
-                gathered_hidden_states = hidden_states
-            eagle_inputs["hidden_states"] = gathered_hidden_states
-
-            for i in range(self.eagle_config.parallel_draft_step - 1):
-                eagle_inputs["input_ids"] = torch.cat(
-                    (
-                        eagle_inputs["input_ids"],
-                        torch.full(
-                            padded_input_ids.shape,
-                            getattr(self, f"mask_token_{i}"),
-                            device=padded_input_ids.device,
-                            dtype=padded_input_ids.dtype,
-                        ),
-                    ),
-                    dim=-1,
-                )
-
-                eagle_inputs["hidden_states"] = torch.cat(
-                    (
-                        eagle_inputs["hidden_states"],
-                        torch.zeros(
-                            (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device
-                        ),
-                        gathered_hidden_states[: -(1 + i)],
-                    ),
-                    dim=0,
-                )
-
-                eagle_inputs["position_ids"] = torch.cat(
-                    (eagle_inputs["position_ids"], position_ids), dim=-1
-                )
-
-                if rotary_pos_emb is not None:
-                    eagle_inputs["rotary_pos_emb"] = torch.cat(
-                        (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
-                    )
-
-            if self.config.sequence_parallel:
-                eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(
-                    eagle_inputs["hidden_states"]
-                )
-
-            eagle_inputs["attention_mask"] = set_multi_step_attention_mask(
-                attn_mask, self.eagle_config.parallel_draft_step
-            )
         elif features.shape[0] == hidden_states.shape[0]:
             eagle_inputs["input_ids"] = torch.cat(
                 (padded_input_ids, padded_input_ids),
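
For context, the hunk removed by this revert built the inputs for EAGLE's extra parallel draft steps by appending a per-step mask token to input_ids and a zero-padded, right-shifted copy of the hidden states, repeating position_ids for each appended segment. Below is a minimal standalone sketch of that concatenation pattern in plain PyTorch; the function name, the explicit mask_token_ids argument, and the toy shapes are illustrative assumptions, and the sequence-parallel gather/scatter, rotary embeddings, and multi-step attention mask from the original code are deliberately left out.

import torch

def build_parallel_draft_inputs(input_ids, hidden_states, position_ids,
                                mask_token_ids, parallel_draft_step):
    """Sketch only. input_ids/position_ids: [b, s]; hidden_states: [s, b, h]."""
    eagle_ids, eagle_hidden, eagle_pos = input_ids, hidden_states, position_ids
    _, b, h = hidden_states.shape
    for i in range(parallel_draft_step - 1):
        # Append one full-length segment of the i-th mask token after the real ids
        # (the original code reads registered mask_token_{i} attributes instead).
        eagle_ids = torch.cat(
            (eagle_ids, torch.full_like(input_ids, mask_token_ids[i])), dim=-1
        )
        # Append hidden states shifted right by (1 + i) positions, zero-padded at the front.
        shifted = torch.cat(
            (
                torch.zeros(1 + i, b, h, dtype=hidden_states.dtype, device=hidden_states.device),
                hidden_states[: -(1 + i)],
            ),
            dim=0,
        )
        eagle_hidden = torch.cat((eagle_hidden, shifted), dim=0)
        # Position ids simply repeat for the appended segment.
        eagle_pos = torch.cat((eagle_pos, position_ids), dim=-1)
    return eagle_ids, eagle_hidden, eagle_pos

# Toy usage: one extra draft step doubles the sequence dimension of every input.
ids = torch.arange(8).view(1, 8)
hs = torch.randn(8, 1, 4)
pos = torch.arange(8).view(1, 8)
out_ids, out_hs, out_pos = build_parallel_draft_inputs(ids, hs, pos, [1000], 2)
print(out_ids.shape, out_hs.shape, out_pos.shape)
# torch.Size([1, 16]) torch.Size([16, 1, 4]) torch.Size([1, 16])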
