Skip to content

Commit 5982faf

Browse files
committed
Fix pseudo speculative generation for parallel draft
Signed-off-by: Ye Yu <[email protected]>
1 parent acd0aa6 commit 5982faf

File tree

1 file changed

+17
-24
lines changed

1 file changed

+17
-24
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,12 +1491,11 @@ def pseudo_speculative_generate(
14911491

14921492
draft_tokens = []
14931493
for _ in range(steps):
1494-
if self.eagle_config.parallel_draft_step > 1:
1495-
for i in range(self.eagle_config.parallel_draft_step - 1):
1496-
eagle_ids = torch.cat(
1497-
(eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1
1498-
)
1499-
hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
1494+
for i in range(self.eagle_config.parallel_draft_step - 1):
1495+
eagle_ids = torch.cat(
1496+
(eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1
1497+
)
1498+
hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
15001499
padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
15011500
eagle_ids, hidden_states
15021501
)
@@ -1530,31 +1529,25 @@ def pseudo_speculative_generate(
15301529
)
15311530
eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :]
15321531

1533-
if self.eagle_config.parallel_draft_step > 1:
1534-
draft_token = (
1535-
gather_from_tensor_model_parallel_region(eagle_logits)[
1536-
-self.eagle_config.parallel_draft_step :, :, :
1537-
]
1538-
.argmax(dim=-1)
1539-
.transpose(0, 1)
1540-
)
1541-
else:
1542-
draft_token = (
1543-
gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :]
1544-
.argmax(dim=-1)
1545-
.transpose(0, 1)
1546-
)
1532+
draft_token = (
1533+
gather_from_tensor_model_parallel_region(eagle_logits)[
1534+
-self.eagle_config.parallel_draft_step :, :, :
1535+
]
1536+
.argmax(dim=-1)
1537+
.transpose(0, 1)
1538+
)
15471539
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
15481540
draft_token += self.eagle_module.d2t[draft_token]
15491541

1550-
if self.eagle_config.parallel_draft_step > 1:
1551-
return base_token, draft_token
1552-
15531542
draft_tokens.append(draft_token)
15541543

15551544
eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
15561545
hidden_states = torch.cat(
1557-
(hidden_states, eagle_next_hidden_states_input[-1:, :, :]), dim=0
1546+
(
1547+
hidden_states,
1548+
eagle_next_hidden_states_input[-self.eagle_config.parallel_draft_step :, :, :],
1549+
),
1550+
dim=0,
15581551
)
15591552

15601553
draft_tokens = torch.cat(draft_tokens, dim=-1)

0 commit comments

Comments
 (0)