diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py
index 0b2a886f5..c2447367e 100644
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -524,9 +524,11 @@ def __init__(
         if self._num_aux_hidden_states > 0:
             # Register forward hook to the last EAGLE3 layer to extract the pre-norm hidden_state
             # for eagle3 auto regression.
-            layer = self.decoder.layers[-1]
-            layer.register_forward_hook(self._eagle3_layer_forward_hook)
+            last_layer = self.decoder.layers[-1]
+            last_layer.register_forward_hook(self._eagle3_layer_forward_hook)
 
+            # The first EAGLE3 layer needs to be specialized.
+            layer = self.decoder.layers[0]
             self_attention = layer.self_attention
             if not isinstance(self_attention, SelfAttention):
                 raise ValueError("EAGLE-3 only support SelfAttention (MHA, GQA).")
diff --git a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py b/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py
index 48719e4b3..0b9eda318 100644
--- a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py
+++ b/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py
@@ -32,6 +32,12 @@
 from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel
 from modelopt.torch.speculative.utils import Tree, get_default_attention_mask_and_position_ids
 
+ALGO_TO_CONFIG = {
+    "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG,
+    "eagle3": mtsp.config.EAGLE3_DEFAULT_CFG,
+    "eagle-mtp": mtsp.config.EAGLE_MTP_DEFAULT_CFG,
+}
+
 
 def _test_speculative_gpt_model(
     algo, num_medusa_heads_or_eagle_layers, activation_func, normalization, rank, size
@@ -64,18 +70,42 @@ def _test_speculative_gpt_model(
 
         # Type checking
         assert isinstance(model, _DynamicMedusaGPTModel)
-    elif algo == "eagle":
-        config = {"eagle_architecture_config": deepcopy(default_eagle_config)}
-        config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
-        config["eagle_architecture_config"]["vocab_size"] = model.vocab_size
-        config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
+    elif algo in {"eagle1", "eagle3"}:
+        mtsp_config = ALGO_TO_CONFIG[algo]
+
+        mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = (
+            num_medusa_heads_or_eagle_layers
+        )
+        mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
+        mtsp_config["config"]["eagle_architecture_config"]["vocab_size"] = model.vocab_size
+        mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
 
-        model = mtsp.convert(model, [("eagle", config)])
+        model = mtsp.convert(model, mtsp_config)
 
         # Type checking
         assert isinstance(model, _DynamicEagleGPTModel)
     else:
-        raise ValueError("Only algo={eagle, medusa} are supported!")
+        raise ValueError("Only algo={eagle1, eagle3, medusa} are supported!")
+
+    if algo == "eagle3":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle3 QKV input_dim is 2x of hidden_size
+        assert (
+            first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size * 2
+        )
+        # Eagle3 attention has a forward_pre_hook to handle additional features to be concatenated
+        assert len(first_layer.self_attention._forward_pre_hooks) > 0
+        # Eagle3 last layer has a forward hook to extract the pre_norm hidden_state
+        assert len(last_layer._forward_hooks) > 0
+    elif algo == "eagle1":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle1 QKV input_dim is the same as hidden_size
+        assert first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size
+        # No forward_hook or forward_pre_hook is needed
+        assert len(first_layer.self_attention._forward_pre_hooks) == 0
+        assert len(last_layer._forward_hooks) == 0
 
     # Bfloat16
     model = model.to(torch.bfloat16)
@@ -104,7 +134,7 @@
         assert medusa_loss.shape[0] == batch_size
         assert medusa_loss.shape[1] == max_sequence_length
 
-    elif algo == "eagle":
+    elif algo in {"eagle1", "eagle3"}:
         labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
         eagle_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels)
 
@@ -115,8 +145,10 @@
 @pytest.mark.parametrize(
     ("algo", "num_medusa_heads_or_eagle_layers", "activation_func", "normalization"),
     [
-        ("eagle", 1, "squared_relu", "LayerNorm"),  # MHA
-        ("eagle", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle1", 1, "squared_relu", "LayerNorm"),  # MHA
+        ("eagle1", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 1, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 2, "swiglu", "RMSNorm"),  # GQA
         ("medusa", 1, "squared_relu", "LayerNorm"),  # MHA
         ("medusa", 2, "swiglu", "RMSNorm"),  # GQA
     ],
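Reviewer note: the new asserts inspect torch.nn.Module hook registries (`_forward_hooks`, `_forward_pre_hooks`) and the widened `linear_qkv` input dim. Below is a minimal, self-contained PyTorch sketch of the hook pattern this diff relies on: a forward hook on the last layer captures its output for reuse on the next draft step, and a forward pre-hook on the first layer widens the attention input by concatenating that captured feature, which is why the first EAGLE-3 layer's QKV input dimension is 2x hidden_size. `ToyLayer`, the hook functions, and the `captured` dict are illustrative stand-ins, not modelopt APIs; the real logic lives in `_eagle3_layer_forward_hook` and the first-layer specialization in `megatron_eagle.py`.

```python
import torch
import torch.nn as nn

hidden_size = 8


class ToyLayer(nn.Module):
    """Stand-in 'decoder layer': a norm followed by a projection back to hidden_size."""

    def __init__(self, in_features, hidden_size):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(in_features)
        # The first EAGLE-3 layer takes a 2x-wide input, hence in_features may be 2 * hidden_size.
        self.linear_qkv = nn.Linear(in_features, hidden_size)

    def forward(self, hidden_states):
        return self.linear_qkv(self.input_layernorm(hidden_states))


first_layer = ToyLayer(2 * hidden_size, hidden_size)  # consumes [embeddings ; aux feature]
last_layer = ToyLayer(hidden_size, hidden_size)

captured = {}


def capture_hook(module, args, output):
    # Forward hook on the last layer: stash its output so the next draft step
    # can reuse it as an auxiliary feature (the real hook extracts the pre-norm hidden_state).
    captured["aux_hidden_state"] = output


def concat_aux_pre_hook(module, args):
    # Forward pre-hook on the first layer: concatenate the auxiliary feature onto
    # the incoming hidden states, doubling the last dimension.
    (hidden_states,) = args
    aux = captured.get("aux_hidden_state", torch.zeros_like(hidden_states))
    return (torch.cat([hidden_states, aux], dim=-1),)


last_layer.register_forward_hook(capture_hook)
first_layer.register_forward_pre_hook(concat_aux_pre_hook)

x = torch.randn(2, hidden_size)
_ = last_layer(x)                   # populates captured["aux_hidden_state"]
draft_in = torch.randn(2, hidden_size)
draft_out = first_layer(draft_in)   # pre-hook widens the input to 2 * hidden_size
print(draft_out.shape)              # torch.Size([2, 8])
```

In this toy the pre-hook is registered on the layer itself rather than on its self-attention submodule, but the registries the new test checks (`_forward_pre_hooks` and `_forward_hooks`) are populated the same way by `register_forward_pre_hook` and `register_forward_hook`.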