Commit 753c8a1

debug: move learnable parallel draft embedding and hidden_states after base model param freeze
Signed-off-by: Ye Yu <[email protected]>
1 parent d3e1e0a · commit 753c8a1

File tree

1 file changed: +14 -12 lines changed


modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 14 additions & 12 deletions
@@ -674,15 +674,6 @@ def modify(
                 "Only logit distillation is supported when draft_vocab_size != vocab_size!"
             )
 
-        # Set up learnable parallel draft embeddings and hidden_states
-        if self.eagle_config.parallel_draft_step > 1:
-            self.parallel_draft_embeddings = torch.nn.Parameter(
-                torch.rand(self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size)
-            )
-            self.parallel_draft_hidden_states = torch.nn.Parameter(
-                torch.rand(self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size)
-            )
-
         # Use default aux_hidden_state layers if use_aux_hidden_state is True
         # but no layer id is given
         # layer ids are not used in offline eagle, but we need to set this to have correct fc_input_size_multiplier
@@ -764,6 +755,19 @@ def modify(
         # Eagle loss functions
         self.kld = logits_kld_loss
 
+        # Set up learnable parallel draft embeddings and hidden_states
+        if self.eagle_config.parallel_draft_step > 1:
+            self.parallel_draft_embeddings = torch.nn.Parameter(
+                torch.rand(
+                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
+                )
+            )
+            self.parallel_draft_hidden_states = torch.nn.Parameter(
+                torch.rand(
+                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
+                )
+            )
+
     def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
         """When _aux_hidden_states is not empty for online, then this is EAGLE-3.
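Why the move matters: per the commit message, the base model's parameters are frozen between the old location of this block and the new one. A freeze that iterates over module.parameters() sweeps up every parameter registered at that point, so creating the learnable parallel draft embeddings and hidden_states before the freeze would silently make them untrainable; creating them after keeps them trainable. A minimal sketch of the ordering hazard, assuming the freeze walks all registered parameters (the Toy module and freeze_all_params helper are hypothetical, not the actual modelopt code):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(8, 8)  # stands in for the base model

def freeze_all_params(module: torch.nn.Module):
    # Typical base-model freeze: visits every currently registered parameter.
    for param in module.parameters():
        param.requires_grad = False

model = Toy()

# Registered BEFORE the freeze: swept up and silently made untrainable.
model.early_param = torch.nn.Parameter(torch.rand(2, 8))
freeze_all_params(model)
print(model.early_param.requires_grad)  # False

# Registered AFTER the freeze: stays trainable, which is what the
# learnable parallel draft embeddings/hidden_states need.
model.late_param = torch.nn.Parameter(torch.rand(2, 8))
print(model.late_param.requires_grad)  # True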
@@ -1389,9 +1393,7 @@ def pseudo_speculative_generate(
         for _ in range(self.eagle_config.parallel_draft_step - 1):
             # Pad dummy eagle_ids and hidden_states for parallel draft
             # They will be replaced by parallel draft embeddings and hidden_states after padding
-            eagle_ids = torch.cat(
-                (eagle_ids, torch.zeros(1, 1).to(eagle_ids.dtype).to(eagle_ids.device)), dim=-1
-            )
+            eagle_ids = torch.cat((eagle_ids, eagle_ids[:, -1:]), dim=-1)
             hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
             padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
                 eagle_ids, hidden_states
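On the pseudo_speculative_generate change: per the in-diff comment, the padded positions are placeholders that get overwritten by the parallel draft embeddings and hidden_states after right_padding, so their values do not matter. Repeating the last column inherits dtype, device, and batch size from eagle_ids automatically, whereas torch.zeros(1, 1) hard-coded a batch size of 1 and needed explicit dtype/device conversion. A small standalone sketch of the two padding styles (toy tensor values, not from the repo):

import torch

eagle_ids = torch.tensor([[5, 17, 42]])  # (batch, seq) of token ids

# Before: pad with a literal zero token, manually matching dtype and
# device, and assuming batch size 1.
padded_old = torch.cat(
    (eagle_ids, torch.zeros(1, 1).to(eagle_ids.dtype).to(eagle_ids.device)), dim=-1
)

# After: repeat the last column; dtype, device, and batch size follow
# automatically, and the value is irrelevant since the position is
# replaced downstream anyway.
padded_new = torch.cat((eagle_ids, eagle_ids[:, -1:]), dim=-1)

print(padded_old)  # tensor([[ 5, 17, 42,  0]])
print(padded_new)  # tensor([[ 5, 17, 42, 42]])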