Commit 67af656

Fixing multi-layer eagle3 forward hook placement (#295)
Signed-off-by: Chenhan Yu <[email protected]>
1 parent 0d279f1 commit 67af656

File tree

2 files changed (+46, -12 lines)
modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 4 additions & 2 deletions
@@ -524,9 +524,11 @@ def __init__(
         if self._num_aux_hidden_states > 0:
             # Register forward hook to the last EAGLE3 layer to extract the pre-norm hidden_state
             # for eagle3 auto regression.
-            layer = self.decoder.layers[-1]
-            layer.register_forward_hook(self._eagle3_layer_forward_hook)
+            last_layer = self.decoder.layers[-1]
+            last_layer.register_forward_hook(self._eagle3_layer_forward_hook)
 
+            # The first EAGLE3 layer needs to be specialized.
+            layer = self.decoder.layers[0]
             self_attention = layer.self_attention
             if not isinstance(self_attention, SelfAttention):
                 raise ValueError("EAGLE-3 only support SelfAttention (MHA, GQA).")
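
For readers unfamiliar with the hook mechanics, here is a minimal, self-contained PyTorch sketch of the pattern this fix restores: with more than one EAGLE3 layer, the forward hook that captures the pre-norm hidden_state must sit on the last decoder layer, while the first layer is the one whose attention gets specialized. The toy modules below are hypothetical stand-ins, not the ModelOpt/Megatron classes touched by this commit.

```python
# Toy sketch (hypothetical modules): hook the *last* layer so the captured
# hidden_state is the one produced right before the final norm, which is what
# EAGLE3 auto-regression feeds back. Hooking layers[0] would grab an
# intermediate activation instead once num_layers > 1.
import torch
from torch import nn


class ToyLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(hidden_states))


class ToyEagleDecoder(nn.Module):
    def __init__(self, num_layers: int, hidden_size: int):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer(hidden_size) for _ in range(num_layers)])
        self.final_norm = nn.LayerNorm(hidden_size)
        self._pre_norm_hidden_state = None
        # Register the capture hook on the last layer, mirroring the fix above.
        self.layers[-1].register_forward_hook(self._capture_pre_norm)

    def _capture_pre_norm(self, module, args, output):
        # Called after the last layer runs; stash its output before final_norm.
        self._pre_norm_hidden_state = output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return self.final_norm(hidden_states)


decoder = ToyEagleDecoder(num_layers=2, hidden_size=8)
out = decoder(torch.randn(1, 4, 8))
# The captured tensor is the pre-norm state and (almost surely) differs from
# the normalized output returned by forward().
assert decoder._pre_norm_hidden_state is not None
assert not torch.allclose(decoder._pre_norm_hidden_state, out)
```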

tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py

Lines changed: 42 additions & 10 deletions
@@ -32,6 +32,12 @@
 from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel
 from modelopt.torch.speculative.utils import Tree, get_default_attention_mask_and_position_ids
 
+ALGO_TO_CONFIG = {
+    "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG,
+    "eagle3": mtsp.config.EAGLE3_DEFAULT_CFG,
+    "eagle-mtp": mtsp.config.EAGLE_MTP_DEFAULT_CFG,
+}
+
 
 def _test_speculative_gpt_model(
     algo, num_medusa_heads_or_eagle_layers, activation_func, normalization, rank, size
@@ -64,18 +70,42 @@ def _test_speculative_gpt_model(
 
         # Type checking
         assert isinstance(model, _DynamicMedusaGPTModel)
-    elif algo == "eagle":
-        config = {"eagle_architecture_config": deepcopy(default_eagle_config)}
-        config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
-        config["eagle_architecture_config"]["vocab_size"] = model.vocab_size
-        config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
+    elif algo in {"eagle1", "eagle3"}:
+        mtsp_config = ALGO_TO_CONFIG[algo]
+
+        mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = (
+            num_medusa_heads_or_eagle_layers
+        )
+        mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
+        mtsp_config["config"]["eagle_architecture_config"]["vocab_size"] = model.vocab_size
+        mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
 
-        model = mtsp.convert(model, [("eagle", config)])
+        model = mtsp.convert(model, mtsp_config)
 
         # Type checking
         assert isinstance(model, _DynamicEagleGPTModel)
     else:
-        raise ValueError("Only algo={eagle, medusa} are supported!")
+        raise ValueError("Only algo={eagle1, eagle3, medusa} are supported!")
+
+    if algo == "eagle3":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle3 QKV input_dim is 2x of hidden_size
+        assert (
+            first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size * 2
+        )
+        # Eagle3 attention has a forward_pre_hook to handle additional features to be concatenated
+        assert len(first_layer.self_attention._forward_pre_hooks) > 0
+        # Eagle3 last layer has a forward hook to extract the pre_norm hidden_state
+        assert len(last_layer._forward_hooks) > 0
+    elif algo == "eagle1":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle1 QKV input_dim is the same as hidden_size
+        assert first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size
+        # No forward_hook or forward_pre_hook is needed
+        assert len(first_layer.self_attention._forward_pre_hooks) == 0
+        assert len(last_layer._forward_hooks) == 0
 
     # Bfloat16
     model = model.to(torch.bfloat16)
@@ -104,7 +134,7 @@ def _test_speculative_gpt_model(
 
         assert medusa_loss.shape[0] == batch_size
         assert medusa_loss.shape[1] == max_sequence_length
-    elif algo == "eagle":
+    elif algo in {"eagle1", "eagle3"}:
         labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
         eagle_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels)
 
@@ -115,8 +145,10 @@ def _test_speculative_gpt_model(
 @pytest.mark.parametrize(
     ("algo", "num_medusa_heads_or_eagle_layers", "activation_func", "normalization"),
     [
-        ("eagle", 1, "squared_relu", "LayerNorm"),  # MHA
-        ("eagle", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle1", 1, "squared_relu", "LayerNorm"),  # MHA
+        ("eagle1", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 1, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 2, "swiglu", "RMSNorm"),  # GQA
         ("medusa", 1, "squared_relu", "LayerNorm"),  # MHA
         ("medusa", 2, "swiglu", "RMSNorm"),  # GQA
     ],
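
The new assertions rely on the fact that torch.nn.Module keeps registered hooks in the private `_forward_hooks` and `_forward_pre_hooks` dicts, so the test can verify hook placement without running a forward pass. A minimal sketch of that behavior follows; the `nn.Linear` layer is just a placeholder, not an EAGLE decoder layer.

```python
# Sketch: nn.Module exposes registered hooks via private OrderedDicts, which is
# what the test asserts on for the first (pre-hook) and last (forward hook)
# EAGLE3 decoder layers. The lambdas below are dummy hooks.
import torch
from torch import nn

layer = nn.Linear(8, 8)

# Before registration, both dicts are empty (the eagle1 case in the test).
assert len(layer._forward_pre_hooks) == 0
assert len(layer._forward_hooks) == 0

layer.register_forward_pre_hook(lambda module, args: None)      # e.g. concat extra input features
layer.register_forward_hook(lambda module, args, output: None)  # e.g. capture pre-norm hidden_state

# After registration, both dicts are non-empty (the eagle3 case in the test).
assert len(layer._forward_pre_hooks) == 1
assert len(layer._forward_hooks) == 1
```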
