
Commit bfdd637

debug

Signed-off-by: Ye Yu <[email protected]>
1 parent 753c8a1


1 file changed: modelopt/torch/speculative/plugins/megatron_eagle.py (13 additions, 17 deletions)
@@ -480,6 +480,15 @@ def __init__(
             skip_weight_param_allocation=False,
         )
 
+        # Set up learnable parallel draft embeddings and hidden_states
+        if config.parallel_draft_step > 1:
+            self.parallel_draft_embeddings = torch.nn.Parameter(
+                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
+            )
+            self.parallel_draft_hidden_states = torch.nn.Parameter(
+                torch.rand(config.parallel_draft_step - 1, config.hidden_size)
+            )
+
     def _get_eagle_transformer_layer_spec(self, config):
         """Get the TransformerLayer implementation spec.
 
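For reference, the block added above registers the placeholder tensors as torch.nn.Parameter objects, so they are trained, serialized, and moved along with the module that owns them. A minimal standalone sketch of the same pattern (class and argument names here are illustrative, not taken from the repo):

import torch

class DraftPlaceholders(torch.nn.Module):
    """Illustrative module: one learnable row per draft position beyond the first."""

    def __init__(self, parallel_draft_step: int, hidden_size: int):
        super().__init__()
        if parallel_draft_step > 1:
            # Registered as Parameters so they show up in .parameters()
            # and in the module's state dict.
            self.parallel_draft_embeddings = torch.nn.Parameter(
                torch.rand(parallel_draft_step - 1, hidden_size)
            )
            self.parallel_draft_hidden_states = torch.nn.Parameter(
                torch.rand(parallel_draft_step - 1, hidden_size)
            )

m = DraftPlaceholders(parallel_draft_step=3, hidden_size=8)
print([p.shape for p in m.parameters()])  # [torch.Size([2, 8]), torch.Size([2, 8])]

Because the parameters are now created in the eagle module's __init__ rather than attached later in modify(), they exist as soon as the module is constructed.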
@@ -755,19 +764,6 @@ def modify(
         # Eagle loss functions
         self.kld = logits_kld_loss
 
-        # Set up learnable parallel draft embeddings and hidden_states
-        if self.eagle_config.parallel_draft_step > 1:
-            self.parallel_draft_embeddings = torch.nn.Parameter(
-                torch.rand(
-                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
-                )
-            )
-            self.parallel_draft_hidden_states = torch.nn.Parameter(
-                torch.rand(
-                    self.eagle_config.parallel_draft_step - 1, self.eagle_config.hidden_size
-                )
-            )
-
     def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
         """When _aux_hidden_states is not empty for online, then this is EAGLE-3.
 
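The removal above is the other half of the move: the same parameters were previously attached to the outer model in modify(). Once they live on the submodule, callers reach them through the submodule attribute, and they are checkpointed under its prefix. A toy sketch (Wrapper and Eagle are hypothetical names, though eagle_module matches the attribute used in the hunks below):

import torch

class Eagle(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.parallel_draft_embeddings = torch.nn.Parameter(torch.rand(2, 8))

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.eagle_module = Eagle()

w = Wrapper()
# The parameter is reached through the submodule attribute ...
print(w.eagle_module.parallel_draft_embeddings.shape)  # torch.Size([2, 8])
# ... and appears under the submodule's prefix in the state dict.
print(list(w.state_dict().keys()))  # ['eagle_module.parallel_draft_embeddings']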
@@ -824,10 +820,10 @@ def _get_eagle_module_inputs(
             )
             eagle_inputs["hidden_states"] = hidden_states
         else:
-            eagle_inputs["embedding"] = self.parallel_draft_embeddings[
+            eagle_inputs["embedding"] = self.eagle_module.parallel_draft_embeddings[
                 parallel_draft_index - 1
             ].repeat(hidden_states.shape[0], hidden_states.shape[1], 1)
-            eagle_inputs["hidden_states"] = self.parallel_draft_hidden_states[
+            eagle_inputs["hidden_states"] = self.eagle_module.parallel_draft_hidden_states[
                 parallel_draft_index - 1
             ].repeat(hidden_states.shape[0], hidden_states.shape[1], 1)
 
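A note on the .repeat call in this hunk: indexing one row of the parameter yields a (hidden_size,) vector, and .repeat(dim0, dim1, 1) tiles it to match the first two dimensions of hidden_states (whether those are sequence and batch depends on the surrounding layout). A toy illustration with made-up sizes:

import torch

hidden_size = 8
row = torch.rand(hidden_size)                    # one learned placeholder row
hidden_states = torch.rand(16, 2, hidden_size)   # e.g. (seq_len, batch, hidden)

# Tile the row so every position in the first two dims gets the same vector.
tiled = row.repeat(hidden_states.shape[0], hidden_states.shape[1], 1)
print(tiled.shape)  # torch.Size([16, 2, 8])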
@@ -1417,10 +1413,10 @@ def pseudo_speculative_generate(
             # parallel_draft_embeddings and parallel_draft_hidden_states
             gathered_embedding[
                 seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
-            ] = self.parallel_draft_embeddings.unsqueeze(1)
+            ] = self.eagle_module.parallel_draft_embeddings.unsqueeze(1)
             padded_hidden_states[
                 seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len
-            ] = self.parallel_draft_hidden_states.unsqueeze(1)
+            ] = self.eagle_module.parallel_draft_hidden_states.unsqueeze(1)
             if self.config.sequence_parallel:
                 padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states)
                 embeddings = scatter_to_sequence_parallel_region(gathered_embedding)
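On the unsqueeze(1) in this hunk: the parameter has shape (parallel_draft_step - 1, hidden_size), and unsqueezing inserts a singleton middle dimension so the slice assignment broadcasts across the batch dimension. A toy check with illustrative sizes:

import torch

steps, batch, hidden = 4, 2, 8   # illustrative sizes, not from the repo
seq_len = 10
gathered_embedding = torch.zeros(seq_len, batch, hidden)
draft = torch.rand(steps - 1, hidden)

# (steps-1, hidden) -> (steps-1, 1, hidden); broadcasts over the batch dim.
gathered_embedding[seq_len - steps + 1 : seq_len] = draft.unsqueeze(1)
print(gathered_embedding[seq_len - steps + 1 : seq_len].shape)  # torch.Size([3, 2, 8])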
