 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
 from megatron.core.extensions.transformer_engine import TENorm
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.models.gpt import GPTModel
@@ -240,23 +241,15 @@ def set_multi_step_attention_mask(attn_mask, step):
     s = attn_mask.shape[-1]
     for iter in range(2, step + 1):
         # iter starts from 2nd step
-        zero_mask = attn_mask.new_ones(
-            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
-        ).bool()
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0 = attn_mask.clone().detach()
         mask_0[:, :, iter - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
         for i in range(iter - 1, s - 1):
             mask_1[:, :, i, i] = False

-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_0, mask_1), dim=-1),
-            ),
-            dim=-2,
-        )
+        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
+
     return attn_mask


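Side note on the hunk above (not part of the patch): with a kv cache holding earlier steps' keys and values, only the current step's `s` query rows still need a mask, so the rewritten helper keeps the row count fixed and only appends `s` key columns per step, instead of growing both dimensions as before. A minimal sketch under that reading, with a made-up helper name and tiny tensors:

```python
import torch

def multi_step_mask(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    """Mirror of the patched set_multi_step_attention_mask logic (True == masked out)."""
    s = attn_mask.shape[-1]
    for it in range(2, step + 1):
        mask_0 = attn_mask.clone().detach()
        mask_0[:, :, it - 2, :] = True
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(it - 1, s - 1):
            mask_1[:, :, i, i] = False
        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
    return attn_mask

s = 4
causal = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)
print(multi_step_mask(causal, 3).shape)  # torch.Size([1, 1, 4, 12]): rows stay at s, keys grow each step
```

The removed branch stacked `s` extra query rows per step as well (concatenating along `dim=-2`), which the cached keys/values now make unnecessary.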
@@ -516,6 +509,7 @@ def forward(
         rotary_pos_emb: torch.Tensor = None,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
+        inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ) -> torch.Tensor:
         """Forward function."""
@@ -556,6 +550,7 @@ def forward(
             inference_params=inference_params,
             rotary_pos_emb=rotary_pos_emb,
             packed_seq_params=packed_seq_params,
+            inference_context=inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -962,6 +957,7 @@ def _eagle_forward(
         output_weight,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
+        inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ):
         eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = self.eagle_module(
@@ -971,15 +967,23 @@ def _eagle_forward(
             eagle_inputs["rotary_pos_emb"],
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=inference_context,
             **(extra_block_kwargs or {}),
         )

+        # Update inference_context.sequence_len_offset after each call of eagle_module
+        inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+
         if hasattr(self.eagle_module, "eagle_output_layer"):
             eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states)
         else:
             eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight)

-        return eagle_hidden_states, eagle_logits, eagle_hidden_states_pre_final_layernorm
+        return (
+            eagle_hidden_states,
+            eagle_logits,
+            eagle_hidden_states_pre_final_layernorm,
+        )

     def forward(
         self,
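A toy illustration (not from the patch) of why the offset must advance after every eagle pass: keys and values produced by a pass are written starting at the current offset, so advancing it by the number of input tokens keeps successive passes from overwriting each other. The cache layout and sizes below are invented for the sketch.

```python
import torch

max_seq, batch, heads, dim = 64, 1, 1, 8
key_cache = torch.zeros(max_seq, batch, heads, dim)  # stand-in for one layer's static key cache
offset = 0
for n_tokens in (16, 16):  # two consecutive eagle passes; sizes are made up
    new_keys = torch.randn(n_tokens, batch, heads, dim)
    key_cache[offset : offset + n_tokens] = new_keys  # new entries land at [offset, offset + n_tokens)
    # mirrors: inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
    offset += n_tokens
print(offset)  # 32 positions cached so far
```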
@@ -1033,6 +1037,11 @@ def forward(
         output_weight = self.shared_embedding_or_output_weight()
         logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight)

+        # EAGLE kv cache
+        eagle_inference_context = StaticInferenceContext(
+            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
+        )
+
         if self.eagle_offline:
             eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
                 aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
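For readers unfamiliar with `StaticInferenceContext`: it carries a fixed maximum batch size and sequence length for the kv cache, so the second constructor argument has to cover every token cached across all eagle passes that share the context within one forward, hence the `* 4` budget matching the four `_eagle_forward` call sites touched in this diff. A hedged sketch with invented sizes:

```python
from megatron.core.inference.contexts import StaticInferenceContext

# Invented sizes; in the patch these come from input_ids.shape and
# self.eagle_config.parallel_draft_step.
batch_size, seq_len, parallel_draft_step = 2, 128, 1

# Shared by every eagle pass in this forward, so sized for up to four passes
# of seq_len * parallel_draft_step tokens each.
eagle_ctx = StaticInferenceContext(batch_size, seq_len * parallel_draft_step * 4)
```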
@@ -1075,6 +1084,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -1141,6 +1151,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )
         eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
@@ -1186,6 +1197,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -1232,6 +1244,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -1443,6 +1456,8 @@ def pseudo_speculative_generate(
     ):
         """Pseudo generate of the EAGLE GPTModel.

+        This function does not support kv cache as sequence parallel may be enabled.
+
         Returns:
             base_token (torch.Tensor): token from base model
             draft_tokens (torch.Tensor): draft tokens from eagle module