Skip to content

Commit 5c4c8a0

Browse files
committed
debug
Signed-off-by: Ye Yu <[email protected]>
1 parent 5b1ab4f commit 5c4c8a0

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -781,16 +781,6 @@ def _get_eagle_module_inputs(
781781
eagle_inputs = {}
782782

783783
eagle_inputs["position_ids"] = position_ids
784-
eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
785-
786-
if self.config.sequence_parallel:
787-
gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
788-
gathered_features = (
789-
None if features is None else gather_from_sequence_parallel_region(features)
790-
)
791-
else:
792-
gathered_hidden_states = hidden_states
793-
gathered_features = features
794784

795785
eagle_inputs["input_ids"] = (
796786
padded_input_ids
@@ -803,6 +793,14 @@ def _get_eagle_module_inputs(
803793
)
804794
)
805795

796+
if self.config.sequence_parallel:
797+
gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
798+
gathered_features = (
799+
None if features is None else gather_from_sequence_parallel_region(features)
800+
)
801+
else:
802+
gathered_hidden_states = hidden_states
803+
gathered_features = features
806804
if gathered_features is not None:
807805
feature = gathered_features[-s:]
808806
eagle_inputs["hidden_states"] = (
@@ -829,6 +827,12 @@ def _get_eagle_module_inputs(
829827
attn_mask, (ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step
830828
)
831829

830+
eagle_inputs["rotary_pos_emb"] = torch.cat(
831+
[rotary_pos_emb]
832+
* ((ttt_step - 1) * self.eagle_config.parallel_draft_step + parallel_draft_step),
833+
dim=0,
834+
)
835+
832836
eagle_inputs["embedding"] = self.embedding(
833837
input_ids=eagle_inputs["input_ids"],
834838
position_ids=eagle_inputs["position_ids"],

0 commit comments

Comments
 (0)