We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8ad781a commit 4542a6bCopy full SHA for 4542a6b
modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -1408,6 +1408,11 @@ def pseudo_speculative_generate(
1408
1409
draft_tokens.append(draft_token)
1410
1411
+ # Remove mask tokens from eagle_ids before adding draft_token
1412
+ # Remove added hidden_states before
1413
+ eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1]
1414
+ hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1]
1415
+
1416
eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
1417
hidden_states = torch.cat(
1418
(
0 commit comments