Skip to content

Commit 4542a6b

Browse files
committed
fix the bug in pseudo_speculative_generate
Signed-off-by: Ye Yu <[email protected]>
1 parent 8ad781a commit 4542a6b

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,11 @@ def pseudo_speculative_generate(
14081408

14091409
draft_tokens.append(draft_token)
14101410

1411+
# Remove mask tokens from eagle_ids before adding draft_token
1412+
# Remove added hidden_states before
1413+
eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1]
1414+
hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1]
1415+
14111416
eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
14121417
hidden_states = torch.cat(
14131418
(

0 commit comments

Comments
 (0)