 from modelopt.torch.speculative.plugins.megatron_medusa import _DynamicMedusaGPTModel
 from modelopt.torch.speculative.utils import Tree, get_default_attention_mask_and_position_ids
 
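+# Map each supported algo name to its ModelOpt default speculative-decoding config.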
+ALGO_TO_CONFIG = {
+    "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG,
+    "eagle3": mtsp.config.EAGLE3_DEFAULT_CFG,
+    "eagle-mtp": mtsp.config.EAGLE_MTP_DEFAULT_CFG,
+}
+
 
 def _test_speculative_gpt_model(
     algo, num_medusa_heads_or_eagle_layers, activation_func, normalization, rank, size
@@ -64,18 +70,42 @@ def _test_speculative_gpt_model(
 
         # Type checking
         assert isinstance(model, _DynamicMedusaGPTModel)
-    elif algo == "eagle":
-        config = {"eagle_architecture_config": deepcopy(default_eagle_config)}
-        config["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
-        config["eagle_architecture_config"]["vocab_size"] = model.vocab_size
-        config["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
+    elif algo in {"eagle1", "eagle3"}:
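+        # Copy the default so the shared module-level config dict is not mutated across runs.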
+        mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])
+
+        mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = (
+            num_medusa_heads_or_eagle_layers
+        )
+        mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size
+        mtsp_config["config"]["eagle_architecture_config"]["vocab_size"] = model.vocab_size
+        mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"] = model.vocab_size
 
-        model = mtsp.convert(model, [("eagle", config)])
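+        # Pass the full default config to convert() (replaces the old [("eagle", config)] mode list).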
+        model = mtsp.convert(model, mtsp_config)
 
         # Type checking
         assert isinstance(model, _DynamicEagleGPTModel)
     else:
-        raise ValueError("Only algo={eagle, medusa} are supported!")
+        raise ValueError("Only algo={eagle1, eagle3, medusa} is supported!")
+
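+    # Structural checks: Eagle3 wires extra-feature hooks into the draft module; Eagle1 does not.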
+    if algo == "eagle3":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle3 QKV input dim is 2x hidden_size
+        assert (
+            first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size * 2
+        )
+        # Eagle3 attention has a forward_pre_hook to handle the additional features to be concatenated
+        assert len(first_layer.self_attention._forward_pre_hooks) > 0
+        # Eagle3 last layer has a forward hook to extract the pre_norm hidden_state
+        assert len(last_layer._forward_hooks) > 0
+    elif algo == "eagle1":
+        first_layer = model.eagle_module.decoder.layers[0]
+        last_layer = model.eagle_module.decoder.layers[-1]
+        # Eagle1 QKV input dim is the same as hidden_size
+        assert first_layer.self_attention.linear_qkv.weight.shape[-1] == model.config.hidden_size
+        # No forward_hook or forward_pre_hook is needed
+        assert len(first_layer.self_attention._forward_pre_hooks) == 0
+        assert len(last_layer._forward_hooks) == 0
 
     # Bfloat16
     model = model.to(torch.bfloat16)
@@ -104,7 +134,7 @@ def _test_speculative_gpt_model(
 
         assert medusa_loss.shape[0] == batch_size
         assert medusa_loss.shape[1] == max_sequence_length
-    elif algo == "eagle":
+    elif algo in {"eagle1", "eagle3"}:
         labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
         eagle_loss = model(prompt_tokens, position_ids, attention_mask, labels=labels)
 
@@ -115,8 +145,10 @@ def _test_speculative_gpt_model(
 @pytest.mark.parametrize(
     ("algo", "num_medusa_heads_or_eagle_layers", "activation_func", "normalization"),
     [
-        ("eagle", 1, "squared_relu", "LayerNorm"),  # MHA
-        ("eagle", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle1", 1, "squared_relu", "LayerNorm"),  # MHA
+        ("eagle1", 2, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 1, "swiglu", "RMSNorm"),  # GQA
+        ("eagle3", 2, "swiglu", "RMSNorm"),  # GQA
         ("medusa", 1, "squared_relu", "LayerNorm"),  # MHA
         ("medusa", 2, "swiglu", "RMSNorm"),  # GQA
     ],