@@ -26,6 +26,7 @@
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
 from megatron.core.extensions.transformer_engine import TENorm
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.models.gpt import GPTModel
@@ -240,23 +241,15 @@ def set_multi_step_attention_mask(attn_mask, step):
     s = attn_mask.shape[-1]
     for iter in range(2, step + 1):
         # iter starts from 2nd step
-        zero_mask = attn_mask.new_ones(
-            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
-        ).bool()
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0 = attn_mask.clone().detach()
         mask_0[:, :, iter - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
         for i in range(iter - 1, s - 1):
             mask_1[:, :, i, i] = False

-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_0, mask_1), dim=-1),
-            ),
-            dim=-2,
-        )
+        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
+
     return attn_mask

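A quick sanity check on the reshaped mask (a minimal sketch, not part of the patch): with the nested `torch.cat` along `dim=-2` gone, each extra step appends one `s`-wide block along the key dimension only, so a `(b, h, s, s)` mask becomes `(b, h, s, s * step)` after `step` steps. The query length stays `s`, consistent with later steps reading earlier keys/values from the new KV cache instead of re-attending a concatenated sequence. The sizes below are made up, and the snippet assumes Megatron's `True == masked out` convention and that `set_multi_step_attention_mask` is importable from this module.

```python
import torch

# Toy sizes, chosen only for illustration.
b, h, s, step = 1, 1, 5, 3

# Standard causal mask in the "True = masked out" convention.
attn_mask = torch.triu(torch.ones(b, h, s, s, dtype=torch.bool), diagonal=1)

multi_step_mask = set_multi_step_attention_mask(attn_mask, step)
print(multi_step_mask.shape)  # torch.Size([1, 1, 5, 15]) == (b, h, s, s * step)
```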
@@ -516,6 +509,7 @@ def forward(
         rotary_pos_emb: torch.Tensor = None,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
+        inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ) -> torch.Tensor:
         """Forward function."""
@@ -556,6 +550,7 @@ def forward(
             inference_params=inference_params,
             rotary_pos_emb=rotary_pos_emb,
             packed_seq_params=packed_seq_params,
+            inference_context=inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -962,6 +957,7 @@ def _eagle_forward(
         output_weight,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
+        inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ):
         eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = self.eagle_module(
@@ -971,15 +967,23 @@ def _eagle_forward(
             eagle_inputs["rotary_pos_emb"],
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=inference_context,
             **(extra_block_kwargs or {}),
         )

+        # Update inference_context.sequence_len_offset after each call of eagle_module.
+        inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+
         if hasattr(self.eagle_module, "eagle_output_layer"):
             eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states)
         else:
             eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight)

-        return eagle_hidden_states, eagle_logits, eagle_hidden_states_pre_final_layernorm
+        return (
+            eagle_hidden_states,
+            eagle_logits,
+            eagle_hidden_states_pre_final_layernorm,
+        )

     def forward(
         self,
@@ -1033,6 +1037,11 @@ def forward(
         output_weight = self.shared_embedding_or_output_weight()
         logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight)

+        # EAGLE KV cache
+        eagle_inference_context = StaticInferenceContext(
+            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
+        )
+
         if self.eagle_offline:
             eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
                 aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
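A rough sketch of the capacity arithmetic behind that allocation (an interpretation, not documented megatron-core behavior): the context is sized for roughly four `_eagle_forward` passes, matching the four call sites below that now receive `inference_context=eagle_inference_context`, and each pass advances `sequence_len_offset` by the number of tokens it feeds, as added in `_eagle_forward` above. Names and sizes in the snippet are illustrative only.

```python
# Illustrative numbers only; mirrors StaticInferenceContext(input_ids.shape[0],
# input_ids.shape[1] * parallel_draft_step * 4) from the hunk above.
seq_len = 128                 # stands in for input_ids.shape[1]
parallel_draft_step = 1       # stands in for self.eagle_config.parallel_draft_step
num_eagle_passes = 4          # the four _eagle_forward call sites in this diff

max_sequence_length = seq_len * parallel_draft_step * num_eagle_passes

# Mirrors "inference_context.sequence_len_offset += eagle_inputs['input_ids'].shape[1]",
# assuming each pass feeds about seq_len * parallel_draft_step tokens.
sequence_len_offset = 0
for _ in range(num_eagle_passes):
    sequence_len_offset += seq_len * parallel_draft_step
    assert sequence_len_offset <= max_sequence_length  # cache capacity is never exceeded
```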
@@ -1075,6 +1084,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -1141,6 +1151,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )
         eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
@@ -1186,6 +1197,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -1232,6 +1244,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )

@@ -1443,6 +1456,8 @@ def pseudo_speculative_generate(
     ):
         """Pseudo generate of the EAGLE GPTModel.

+        This function does not support the KV cache, since sequence parallelism may be enabled.
+
         Returns:
             base_token (torch.Tensor): token from base model
             draft_tokens (torch.Tensor): draft tokens from eagle module