Skip to content

Commit d3e1e0a

Browse files
committed
debug
Signed-off-by: Ye Yu <[email protected]>
1 parent fcc38fe commit d3e1e0a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1389,7 +1389,9 @@ def pseudo_speculative_generate(
13891389
for _ in range(self.eagle_config.parallel_draft_step - 1):
13901390
# Pad dummy eagle_ids and hidden_states for parallel draft
13911391
# They will be replaced by parallel draft embeddings and hidden_states after padding
1392-
eagle_ids = torch.cat((eagle_ids, torch.zeros(1, 1).to(eagle_ids.device)), dim=-1)
1392+
eagle_ids = torch.cat(
1393+
(eagle_ids, torch.zeros(1, 1).to(eagle_ids.dtype).to(eagle_ids.device)), dim=-1
1394+
)
13931395
hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
13941396
padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
13951397
eagle_ids, hidden_states

0 commit comments

Comments
 (0)