from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel
from modelopt.torch.speculative.utils import Tree, get_default_attention_mask_and_position_ids

+ALGO_TO_CONFIG = {
+    "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG,
+    "eagle3": mtsp.config.EAGLE3_DEFAULT_CFG,
+    "eagle-mtp": mtsp.config.EAGLE_MTP_DEFAULT_CFG,
+}
+

def _test_speculative_gpt_model(
    algo, num_medusa_heads_or_eagle_layers, activation_func, normalization, rank, size
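For context, the new mode-config plumbing in the hunks below boils down to: take one of the `*_DEFAULT_CFG` dicts, patch its `eagle_architecture_config` to match the base model, and hand the whole dict to `mtsp.convert`. A minimal sketch, assuming the config layout this diff implies (`convert_to_eagle` is a hypothetical helper name, and the `deepcopy` guard is an extra precaution so the shared module-level default is not mutated across parametrized runs):

```python
# Sketch only: assumes each *_DEFAULT_CFG carries a "config" entry with an
# "eagle_architecture_config" sub-dict, as the diff below indicates.
from copy import deepcopy

import modelopt.torch.speculative as mtsp


def convert_to_eagle(model, algo: str, num_eagle_layers: int):
    cfg = deepcopy(ALGO_TO_CONFIG[algo])  # avoid mutating the shared default
    arch = cfg["config"]["eagle_architecture_config"]
    arch["num_hidden_layers"] = num_eagle_layers
    arch["hidden_size"] = model.config.hidden_size
    arch["vocab_size"] = model.vocab_size
    arch["draft_vocab_size"] = model.vocab_size
    return mtsp.convert(model, cfg)  # convert accepts the mode-config dict directly
```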
@@ -64,18 +70,42 @@ def _test_speculative_gpt_model(

        # Type checking
        assert isinstance(model, _DynamicMedusaGPTModel)
-    elif algo == "eagle":
-        config = {"eagle_architecture_config": deepcopy(default_eagle_config)}
-        config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
-        config["eagle_architecture_config"]["vocab_size"] = model.vocab_size
-        config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
+    elif algo in {"eagle1", "eagle3"}:
+        mtsp_config = ALGO_TO_CONFIG[algo]
+
+        mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = (
+            num_medusa_heads_or_eagle_layers
+        )
+        mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
+        mtsp_config["config"]["eagle_architecture_config"]["vocab_size"] = model.vocab_size
+        mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size

-        model = mtsp.convert(model, [("eagle", config)])
+        model = mtsp.convert(model, mtsp_config)

        # Type checking
        assert isinstance(model, _DynamicEagleGPTModel)
    else:
-        raise ValueError("Only algo={eagle, medusa} are supported!")
+        raise ValueError("Only algo={eagle1, eagle3, medusa} are supported!")
+
+    if algo == "eagle3":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle3 QKV input dim is 2x hidden_size
+        assert (
+            first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size * 2
+        )
+        # Eagle3 attention has a forward_pre_hook to concatenate the additional features
+        assert len(first_layer.self_attention._forward_pre_hooks) > 0
+        # Eagle3's last layer has a forward hook to extract the pre-norm hidden_state
+        assert len(last_layer._forward_hooks) > 0
+    elif algo == "eagle1":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle1 QKV input dim is the same as hidden_size
+        assert first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size
+        # No forward hooks or forward pre-hooks are needed
+        assert len(first_layer.self_attention._forward_pre_hooks) == 0
+        assert len(last_layer._forward_hooks) == 0

    # Bfloat16
    model = model.to(torch.bfloat16)
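The eagle3 assertions above hinge on a standard PyTorch hook pattern: a `forward_pre_hook` can rewrite a module's positional arguments before `forward` runs, which is how extra features get concatenated onto the attention input and why `linear_qkv` is built with an input width of `2 * hidden_size`. A toy illustration of the mechanism (a stand-in, not modelopt's actual hook; all names below are invented for the example):

```python
# Toy stand-in for the Eagle3 pattern: a forward_pre_hook widens the input,
# so the projection is constructed with in_features == 2 * hidden_size.
import torch
import torch.nn as nn

hidden_size = 8
qkv = nn.Linear(2 * hidden_size, 3 * hidden_size)  # stand-in for linear_qkv

aux = torch.randn(4, hidden_size)  # e.g. extra hidden features to concatenate


def concat_aux(module, args):
    (x,) = args
    return (torch.cat([x, aux], dim=-1),)  # returned tuple replaces the args


qkv.register_forward_pre_hook(concat_aux)
out = qkv(torch.randn(4, hidden_size))  # caller still passes hidden_size-wide input
assert out.shape[-1] == 3 * hidden_size
assert qkv.weight.shape[-1] == 2 * hidden_size  # mirrors the test's shape check
```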
@@ -104,7 +134,7 @@ def _test_speculative_gpt_model(
        assert medusa_loss.shape[0] == batch_size
        assert medusa_loss.shape[1] == max_sequence_length
-    elif algo == "eagle":
+    elif algo in {"eagle1", "eagle3"}:
        labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
        eagle_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels)
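The forward call above takes `position_ids` and `attention_mask` alongside the prompt tokens; the test presumably derives them from the utility imported at the top of the file. A hedged sketch of that setup (the helper's exact signature and return order are assumed from its name, not verified here):

```python
# Assumed usage of the imported helper; verify the return order against
# modelopt.torch.speculative.utils before relying on this sketch.
prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
attention_mask, position_ids = get_default_attention_mask_and_position_ids(prompt_tokens)
labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
eagle_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels)
```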
@@ -115,8 +145,10 @@ def _test_speculative_gpt_model(
@pytest.mark.parametrize(
    ("algo", "num_medusa_heads_or_eagle_layers", "activation_func", "normalization"),
    [
-        ("eagle", 1, "squared_relu", "LayerNorm"),  # MHA
-        ("eagle", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle1", 1, "squared_relu", "LayerNorm"),  # MHA
+        ("eagle1", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 1, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 2, "swiglu", "RMSNorm"),  # GQA
        ("medusa", 1, "squared_relu", "LayerNorm"),  # MHA
        ("medusa", 2, "swiglu", "RMSNorm"),  # GQA
    ],