
Commit 7390687

implement kv cache (inference_context) for eagle training
Signed-off-by: Ye Yu <[email protected]>
1 parent 5982faf commit 7390687

File tree

1 file changed: +27 −12 lines changed


modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 27 additions & 12 deletions
@@ -26,6 +26,7 @@
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
 from megatron.core.extensions.transformer_engine import TENorm
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
 from megatron.core.models.gpt import GPTModel
@@ -240,23 +241,15 @@ def set_multi_step_attention_mask(attn_mask, step):
     s = attn_mask.shape[-1]
     for iter in range(2, step + 1):
         # iter starts from 2nd step
-        zero_mask = attn_mask.new_ones(
-            attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2], s
-        ).bool()
-        mask_0 = attn_mask.clone().detach()[:, :, -s:, :]
+        mask_0 = attn_mask.clone().detach()
         mask_0[:, :, iter - 2, :] = True
         mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
         mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
         for i in range(iter - 1, s - 1):
             mask_1[:, :, i, i] = False

-        attn_mask = torch.cat(
-            (
-                torch.cat((attn_mask, zero_mask), dim=-1),
-                torch.cat((mask_0, mask_1), dim=-1),
-            ),
-            dim=-2,
-        )
+        attn_mask = torch.cat((mask_0, mask_1), dim=-1)
+
     return attn_mask
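With previous steps' keys and values held in the KV cache, the attention mask no longer needs to grow along the query dimension: each extra step only widens the key dimension by s, instead of doubling both dimensions through the old zero_mask concatenation along dim=-2. Below is a minimal, self-contained sketch of the reworked construction (a standalone reimplementation for illustration, assuming a Megatron-style boolean mask in which True marks positions that must not be attended to):

```python
import torch


def multi_step_mask_with_kv_cache(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
    """Queries stay at length s per step; keys grow to step * s to cover the cached history."""
    s = attn_mask.shape[-1]
    for it in range(2, step + 1):
        mask_0 = attn_mask.clone().detach()
        mask_0[:, :, it - 2, :] = True              # fully mask query row it - 2
        mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]  # shift key columns left by one
        mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
        for i in range(it - 1, s - 1):
            mask_1[:, :, i, i] = False              # unmask the diagonal for query rows it - 1 .. s - 2
        attn_mask = torch.cat((mask_0, mask_1), dim=-1)  # key dimension grows to it * s
    return attn_mask


causal = torch.triu(torch.ones(1, 1, 4, 4), diagonal=1).bool()  # True = masked
print(multi_step_mask_with_kv_cache(causal, step=2).shape)      # torch.Size([1, 1, 4, 8])
```

For step=2 the result is (b, 1, s, 2s), which matches a layout where the first s key columns line up with positions already in the cache and the last s with the step's newly drafted tokens.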

@@ -516,6 +509,7 @@ def forward(
         rotary_pos_emb: torch.Tensor = None,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
+        inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ) -> torch.Tensor:
         """Forward function."""
@@ -556,6 +550,7 @@ def forward(
             inference_params=inference_params,
             rotary_pos_emb=rotary_pos_emb,
             packed_seq_params=packed_seq_params,
+            inference_context=inference_context,
             **(extra_block_kwargs or {}),
         )
@@ -962,6 +957,7 @@ def _eagle_forward(
         output_weight,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
+        inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ):
         eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = self.eagle_module(
@@ -971,15 +967,23 @@ def _eagle_forward(
             eagle_inputs["rotary_pos_emb"],
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=inference_context,
             **(extra_block_kwargs or {}),
         )

+        # Update inference_context.sequence_len_offset after each call of eagle_module
+        inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]
+
         if hasattr(self.eagle_module, "eagle_output_layer"):
             eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states)
         else:
             eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight)

-        return eagle_hidden_states, eagle_logits, eagle_hidden_states_pre_final_layernorm
+        return (
+            eagle_hidden_states,
+            eagle_logits,
+            eagle_hidden_states_pre_final_layernorm,
+        )

     def forward(
         self,
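The two added lines above carry the cache bookkeeping: each `_eagle_forward` call writes `eagle_inputs["input_ids"].shape[1]` new key/value positions, and advancing `sequence_len_offset` tells the next call where its entries begin. Below is a runnable sketch of that pattern using a toy stand-in for the context object; only the capacity argument and the `sequence_len_offset` field mirror the real `StaticInferenceContext`, everything else is illustrative.

```python
from dataclasses import dataclass


@dataclass
class ToyInferenceContext:
    """Stand-in for StaticInferenceContext: tracks only capacity and the write offset."""

    max_sequence_length: int
    sequence_len_offset: int = 0


def draft_call(ctx: ToyInferenceContext, num_new_tokens: int) -> range:
    """Return the cache slots one call writes to, then advance the offset
    (mirrors `inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]`)."""
    start = ctx.sequence_len_offset
    end = start + num_new_tokens
    assert end <= ctx.max_sequence_length, "KV cache capacity exceeded"
    ctx.sequence_len_offset = end
    return range(start, end)


ctx = ToyInferenceContext(max_sequence_length=32)
for call in range(4):  # one training forward issues several eagle_module calls
    slots = draft_call(ctx, num_new_tokens=8)
    print(f"call {call}: cache slots {slots.start}..{slots.stop - 1}")
```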
@@ -1033,6 +1037,11 @@ def forward(
         output_weight = self.shared_embedding_or_output_weight()
         logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight)

+        # EAGLE kv cache
+        eagle_inference_context = StaticInferenceContext(
+            input_ids.shape[0], input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4
+        )
+
         if self.eagle_offline:
             eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
                 aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
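The context is allocated once per training forward with capacity `input_ids.shape[1] * self.eagle_config.parallel_draft_step * 4`. Read against the later hunks, the factor of 4 looks like a budget of one `seq_len * parallel_draft_step` slice for each of the four `_eagle_forward` calls that share `eagle_inference_context`; that is an inference from this diff rather than a documented invariant. A quick check of the arithmetic with illustrative sizes:

```python
# Illustrative sizes only; batch_size, seq_len and parallel_draft_step are assumptions.
batch_size = 2
seq_len = 1024
parallel_draft_step = 1
num_eagle_calls = 4  # the _eagle_forward calls sharing eagle_inference_context below

capacity = seq_len * parallel_draft_step * num_eagle_calls  # second constructor argument
tokens_per_call = seq_len * parallel_draft_step

assert tokens_per_call * num_eagle_calls <= capacity
print(f"StaticInferenceContext({batch_size}, {capacity})")  # StaticInferenceContext(2, 4096)
```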
@@ -1075,6 +1084,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )
@@ -1141,6 +1151,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )
         eagle_logits_1 = eagle_logits_2x[-labels.shape[1] * self.eagle_config.parallel_draft_step :]
@@ -1186,6 +1197,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )
@@ -1232,6 +1244,7 @@ def forward(
             output_weight,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
+            inference_context=eagle_inference_context,
             **(extra_block_kwargs or {}),
         )
@@ -1443,6 +1456,8 @@ def pseudo_speculative_generate(
     ):
         """Pseudo generate of the EAGLE GPTModel.

+        This function does not support kv cache as sequence parallel may be enabled.
+
         Returns:
             base_token (torch.Tensor): token from base model
             draft_tokens (torch.Tensor): draft tokens from eagle module
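The docstring addition notes that this generation path skips the KV cache because sequence parallelism may be enabled. A purely hypothetical sketch of gating cache construction on that flag (this guard is not part of the commit; `sequence_parallel` is the flag name on Megatron's model-parallel config):

```python
from megatron.core.inference.contexts import StaticInferenceContext


def maybe_build_eagle_kv_context(config, batch_size: int, max_cached_tokens: int):
    """Hypothetical helper, not in this commit: skip the KV cache when sequence
    parallelism is enabled and fall back to cache-free recompute instead."""
    if getattr(config, "sequence_parallel", False):
        return None  # pseudo_speculative_generate-style path: recompute without a cache
    return StaticInferenceContext(batch_size, max_cached_tokens)
```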
