Skip to content

Commit acd0aa6

Browse files
committed
debug
Signed-off-by: Ye Yu <[email protected]>
1 parent e646de9 commit acd0aa6

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,9 @@ def _get_eagle_module_inputs(
790790
eagle_inputs["position_ids"] = torch.empty(
791791
0, dtype=position_ids.dtype, device=position_ids.device
792792
)
793-
eagle_inputs["rotary_pos_emb"] = rotary_pos_emb
793+
eagle_inputs["rotary_pos_emb"] = torch.empty(
794+
0, dtype=rotary_pos_emb.dtype, device=rotary_pos_emb.device
795+
)
794796
if self.config.sequence_parallel:
795797
gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states)
796798
gathered_features = (
@@ -831,13 +833,15 @@ def _get_eagle_module_inputs(
831833
eagle_inputs["hidden_states"],
832834
gathered_hidden_states
833835
if step == 0
834-
else (
835-
torch.zeros(
836-
(1, b, h),
837-
dtype=hidden_states.dtype,
838-
device=hidden_states.device,
839-
),
840-
feature[:-1, :, :],
836+
else torch.cat(
837+
(
838+
torch.zeros(
839+
(1, b, h),
840+
dtype=hidden_states.dtype,
841+
device=hidden_states.device,
842+
),
843+
feature[:-1, :, :],
844+
)
841845
),
842846
),
843847
dim=0,

0 commit comments

Comments
 (0)