We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 44e0eb1 commit 13f218cCopy full SHA for 13f218c
modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -1441,8 +1441,9 @@ def pseudo_speculative_generate(
1441
1442
# Remove mask tokens from eagle_ids before adding draft_token
1443
# Remove added hidden_states before
1444
- eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1]
1445
- hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1]
+ if self.eagle_config.parallel_draft_step > 1:
+ eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1]
1446
+ hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1]
1447
1448
eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
1449
hidden_states = torch.cat(
0 commit comments