
Commit 5b1ab4f

debug
Signed-off-by: Ye Yu <[email protected]>
1 parent 651cbf4 commit 5b1ab4f

File tree: 1 file changed


modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 27 additions & 55 deletions
@@ -780,15 +780,9 @@ def _get_eagle_module_inputs(
 
         eagle_inputs = {}
 
-        eagle_inputs["input_ids"] = torch.empty(
-            0, dtype=padded_input_ids.dtype, device=padded_input_ids.device
-        )
-        eagle_inputs["position_ids"] = torch.empty(
-            0, dtype=position_ids.dtype, device=position_ids.device
-        )
-        eagle_inputs["rotary_pos_emb"] = torch.empty(
-            0, dtype=rotary_pos_emb.dtype, device=rotary_pos_emb.device
-        )
+        eagle_inputs["position_ids"] = position_ids
+        eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
+
         if self.config.sequence_parallel:
             gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
             gathered_features = (
@@ -797,56 +791,34 @@ def _get_eagle_module_inputs(
         else:
             gathered_hidden_states = hidden_states
             gathered_features = features
-        eagle_inputs["hidden_states"] = torch.empty(
-            0, dtype=gathered_hidden_states.dtype, device=gathered_hidden_states.device
-        )
 
-        for step in range(ttt_step):
-            for i in range(parallel_draft_step):
-                eagle_inputs["input_ids"] = torch.cat(
-                    (
-                        eagle_inputs["input_ids"],
-                        padded_input_ids
-                        if i == 0
-                        else torch.full(
-                            padded_input_ids.shape,
-                            getattr(self, f"mask_token_{i - 1}"),
-                            device=padded_input_ids.device,
-                            dtype=padded_input_ids.dtype,
-                        ),
-                    ),
-                    dim=-1,
-                )
+        eagle_inputs["input_ids"] = (
+            padded_input_ids
+            if parallel_draft_step == 1
+            else torch.full(
+                padded_input_ids.shape,
+                getattr(self, f"mask_token_{parallel_draft_step - 2}"),
+                device=padded_input_ids.device,
+                dtype=padded_input_ids.dtype,
+            )
+        )
 
-            if step > 0:
-                feature = gathered_features[-s:]
-            eagle_inputs["hidden_states"] = torch.cat(
-                (
-                    eagle_inputs["hidden_states"],
-                    gathered_hidden_states
-                    if step == 0
-                    else torch.cat(
-                        (
-                            torch.zeros(
-                                (1, b, h),
-                                dtype=hidden_states.dtype,
-                                device=hidden_states.device,
-                            ),
-                            feature[:-1, :, :],
-                        )
-                    ),
+        if gathered_features is not None:
+            feature = gathered_features[-s:]
+        eagle_inputs["hidden_states"] = (
+            gathered_hidden_states
+            if ttt_step == 1
+            else torch.cat(
+                (
+                    torch.zeros(
+                        (1, b, h),
+                        dtype=hidden_states.dtype,
+                        device=hidden_states.device,
                     ),
-                dim=0,
-            )
-
-            eagle_inputs["position_ids"] = torch.cat(
-                (eagle_inputs["position_ids"], position_ids), dim=-1
+                    feature[:-1, :, :],
                 )
-
-            if rotary_pos_emb is not None:
-                eagle_inputs["rotary_pos_emb"] = torch.cat(
-                    (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0
-                )
+            )
+        )
 
         if self.config.sequence_parallel:
             eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region(