diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2a0e63a3c..b4b9ffeac 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -26,6 +26,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel @@ -139,7 +140,7 @@ def dict_to_config( def mcore_version_higher_than(target_version: str): - """Check if megatron-core is least this version.""" + """Check if megatron-core is greater than this version.""" return Version(megatron.core.__version__) > Version(target_version) @@ -194,46 +195,13 @@ def set_multi_step_attention_mask(attn_mask, step): h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states) l0 l1 l2 l3 l4 l5 l6 l7 (base labels) + ttt_steps=2 + parallel_draft_step=3 - (1st) | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | - ========================================= - f1 l1 | i1 h0 | x | - f2 l2 | i2 h1 | x x | - f3 l3 | i3 h2 | x x x | - f4 l4 | i4 h3 | x x x x | - f5 l5 | i5 h4 | x x x x x | - f6 l6 | i6 h5 | x x x x x x | - f7 l7 | i7 h6 | x x x x x x x | - -- -- | -- h7 | o o o o o o o o | - ========================================= - - - (2nd) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | - =================================================================== - F1 l1 | i1 h0 | x | | - F2 l2 | i2 h1 | x x | | - F3 l3 | i3 h2 | x x x | | - F4 l4 | i4 h3 | x x x x | | - F5 l5 | i5 h4 | x x x x x | | - F6 l6 | i6 h5 | x x x x x x | | - F7 l7 | i7 h6 | x x x x x x x | | - -- -- | -- h7 | o o o o o o o o | | - =================================================================== - -- -- | i1 -- | | | - G2 l2 | i2 F1 | x o | x | - G3 l3 | i3 F2 | x x o | x | - G4 l4 | i4 F3 | x x x o | x | - G5 l5 | i5 F4 | x x x x o | x | - G6 l6 | i6 F5 | x x x x x o | x | - G7 l7 | i7 F6 | x x x x x x o | x | - -- -- | -- F7 | | | - =================================================================== - - - (3rd) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- G2 G3 G4 G5 G6 G7 | + + ttt_step=0 + | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 | ============================================================================================= F1 l1 | i1 h0 | x | | | F2 l2 | i2 h1 | x x | | | @@ -244,126 +212,68 @@ def set_multi_step_attention_mask(attn_mask, step): F7 l7 | i7 h6 | x x x x x x x | | | -- -- | -- h7 | o o o o o o o o | | | ============================================================================================= - -- -- | i1 -- | | | | - G2 l2 | i2 F1 | x o | x | | - G3 l3 | i3 F2 | x x o | x | | - G4 l4 | i4 F3 | x x x o | x | | - G5 l5 | i5 F4 | x x x x o | x | | - G6 l6 | i6 F5 | x x x x x o | x | | - G7 l7 | i7 F6 | x x x x x x o | x | | - -- -- | -- F7 | | | | - 
============================================================================================= - -- -- | i1 -- | | | | - G2 l2 | i2 F1 | x o | x | | - G3 l3 | i3 F2 | x x o | x | | - G4 l4 | i4 F3 | x x x o | x | | - G5 l5 | i5 F4 | x x x x o | x | | - G6 l6 | i6 F5 | x x x x x o | x | | - G7 l7 | i7 F6 | x x x x x x o | x | | - -- -- | -- F7 | | | | + -- -- | m0 M0 | | | | + G2 l2 | m0 M0 | x o | x | | + G3 l3 | m0 M0 | x x o | x | | + G4 l4 | m0 M0 | x x x o | x | | + G5 l5 | m0 M0 | x x x x o | x | | + G6 l6 | m0 M0 | x x x x x o | x | | + G7 l7 | m0 M0 | x x x x x x o | x | | + -- -- | -- M0 | | | | ============================================================================================= - - - (4th) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- -- H3 H4 H5 H6 H7 | - ======================================================================================================================= - F1 l1 | i1 h0 | x | | | | - F2 l2 | i2 h1 | x x | | | | - F3 l3 | i3 h2 | x x x | | | | - F4 l4 | i4 h3 | x x x x | | | | - F5 l5 | i5 h4 | x x x x x | | | | - F6 l6 | i6 h5 | x x x x x x | | | | - F7 l7 | i7 h6 | x x x x x x x | | | | - -- -- | -- h7 | o o o o o o o o | | | | + -- -- | m1 M0 | | | | + -- -- | m1 M1 | | | | + H3 l3 | m1 M1 | x o o | x o | x | + H4 l4 | m1 M1 | x x o o | x o | x | + H5 l5 | m1 M1 | x x x o o | x o | x | + H6 l6 | m1 M1 | x x x x o o | x o | x | + H7 l7 | m1 M1 | x x x x x o o | x o | x | + -- -- | -- M1 | | | | + + + ttt_step=1 + | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | m0 m0 m0 m0 m0 m0 m0 -- | m1 m1 m1 m1 m1 m1 m1 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | M0 M0 M0 M0 M0 M0 M0 M0 | M1 M1 M1 M1 M1 M1 M1 M1 | ======================================================================================================================= -- -- | i1 -- | | | | | - G2 l2 | i2 F1 | x o | x | | | - G3 l3 | i3 F2 | x x o | x | | | - G4 l4 | i4 F3 | x x x o | x | | | - G5 l5 | i5 F4 | x x x x o | x | | | - G6 l6 | i6 F5 | x x x x x o | x | | | - G7 l7 | i7 F6 | x x x x x x o | x | | | + J2 l2 | i2 F1 | x o | x | | | + J3 l3 | i3 F2 | x x o | x | | | + J4 l4 | i4 F3 | x x x o | x | | | + J5 l5 | i5 F4 | x x x x o | x | | | + J6 l6 | i6 F5 | x x x x x o | x | | | + J7 l7 | i7 F6 | x x x x x x o | x | | | -- -- | -- F7 | | | | | ======================================================================================================================= - -- -- | i1 -- | | | | | - -- -- | i2 -- | | | | | - H3 l3 | i3 G2 | x o o | x o | x | | - H4 l4 | i4 G3 | x x o o | x o | x | | - H5 l5 | i5 G4 | x x x o o | x o | x | | - H6 l6 | i6 G5 | x x x x o o | x o | x | | - H7 l7 | i7 G6 | x x x x x o o | x o | x | | - -- -- | -- G7 | | | | | + -- -- | m0 M0 | | | | | + -- -- | m0 M0 | | | | | + K3 l3 | m0 M0 | x o o | x o | x | | + K4 l4 | m0 M0 | x x o o | x o | x | | + K5 l5 | m0 M0 | x x x o o | x o | x | | + K6 l6 | m0 M0 | x x x x o o | x o | x | | + K7 l7 | m0 M0 | x x x x x o o | x o | x | | + -- -- | -- M0 | | | | | ======================================================================================================================= - -- -- | i1 -- | | | | | - -- -- | i2 -- | | | | | - -- -- | i3 -- | | | | | - K4 l4 | i4 H3 | x | x | x | x | - K5 l5 | i5 H4 | x x | x | x | x | - K6 l6 | i6 H5 | x x x | x | x | x | - K7 l7 | i7 H6 | x x x x | x | x | x | - -- -- | -- H7 | | | | | + -- -- | m1 M1 
| | | | | + -- -- | m1 M1 | | | | | + -- -- | m1 M1 | | | | | + N4 l4 | m1 M1 | x | x | x | x | + N5 l5 | m1 M1 | x x | x | x | x | + N6 l6 | m1 M1 | x x x | x | x | x | + N7 l7 | m1 M1 | x x x x | x | x | x | + -- -- | -- M1 | | | | | ======================================================================================================================= """ # noqa: E501 - assert step > 1, "step should be larger than 1 in multi-step attention mask." - assert step <= 4, "Currently only a step of 4 or smaller is supported!" - s = attn_mask.shape[-1] - zero_mask = torch.ones_like(attn_mask).bool() - mask_2_1 = attn_mask.clone().detach() - mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:] - mask_2_2 = torch.ones_like(attn_mask).bool() - for i in range(1, s - 1): - mask_2_2[:, :, i, i] = False - - if step == 2: - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2), dim=-1), - ), - dim=-2, - ) - return attn_mask - - mask_3_1 = mask_2_1.clone().detach() - mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:] - mask_3_2 = mask_2_2.clone().detach() - mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:] - mask_3_2[:, :, 1, 0] = True - mask_3_3 = mask_2_2.clone().detach() - mask_3_3[:, :, 1, 1] = True - - if step == 3: - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1), - ), - dim=-2, - ) - return attn_mask - - mask_4_1 = mask_3_1.clone().detach() - mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:] - mask_4_2 = mask_3_2.clone().detach() - mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:] - mask_4_2[:, :, 2, 0] = True - mask_4_3 = mask_3_3.clone().detach() - mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:] - mask_4_3[:, :, 2, 1] = True - mask_4_4 = mask_3_3.clone().detach() - mask_4_4[:, :, 2, 2] = True - - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1), - torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1), - ), - dim=-2, - ) + for step_idx in range(step): + mask_0 = attn_mask.clone().detach() + mask_0[:, :, step_idx, :] = True + mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] + mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() + for i in range(step_idx + 1, s - 1): + mask_1[:, :, i, i] = False + + attn_mask = torch.cat((mask_0, mask_1), dim=-1) + return attn_mask @@ -623,6 +533,7 @@ def forward( rotary_pos_emb: torch.Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + inference_context: StaticInferenceContext | None = None, extra_block_kwargs: dict | None = None, ) -> torch.Tensor: """Forward function.""" @@ -663,6 +574,7 @@ def forward( inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, packed_seq_params=packed_seq_params, + inference_context=inference_context, **(extra_block_kwargs or {}), ) @@ -870,14 +782,14 @@ def _get_eagle_module_inputs( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, - features: torch.Tensor | None = None, + ttt_step: int = 0, + parallel_draft_index: int = 0, ): """Getting EAGLE module inputs.""" - b = hidden_states.shape[1] - h = hidden_states.shape[2] - # [b, 1] - id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device) + id_padding = torch.zeros( + (input_ids.shape[0], 1), 
dtype=input_ids.dtype, device=input_ids.device + ) padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) @@ -889,184 +801,37 @@ def _get_eagle_module_inputs( eagle_inputs = {} - if self.eagle_config.parallel_draft_step > 1: - eagle_inputs["input_ids"] = padded_input_ids - eagle_inputs["position_ids"] = position_ids - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - else: - gathered_hidden_states = hidden_states - eagle_inputs["hidden_states"] = gathered_hidden_states - - for i in range(self.eagle_config.parallel_draft_step - 1): - eagle_inputs["input_ids"] = torch.cat( - ( - eagle_inputs["input_ids"], - torch.full( - padded_input_ids.shape, - getattr(self, f"mask_token_{i}"), - device=padded_input_ids.device, - dtype=padded_input_ids.dtype, - ), - ), - dim=-1, - ) - - eagle_inputs["hidden_states"] = torch.cat( - ( - eagle_inputs["hidden_states"], - torch.zeros( - (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device - ), - gathered_hidden_states[: -(1 + i)], - ), - dim=0, - ) - - eagle_inputs["position_ids"] = torch.cat( - (eagle_inputs["position_ids"], position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0 - ) - - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, self.eagle_config.parallel_draft_step - ) - elif features is None: - eagle_inputs["input_ids"] = padded_input_ids - eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = attn_mask - eagle_inputs["position_ids"] = position_ids - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - elif features.shape[0] == hidden_states.shape[0]: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids), - dim=-1, - ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 2) - eagle_inputs["position_ids"] = torch.cat((position_ids, position_ids), dim=-1) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=0) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - elif features.shape[0] == hidden_states.shape[0] * 2: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids, padded_input_ids), - dim=-1, - ) + eagle_inputs["position_ids"] = position_ids - if self.config.sequence_parallel: - 
gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, + eagle_inputs["input_ids"] = ( + padded_input_ids + if parallel_draft_index == 0 + else torch.full( + padded_input_ids.shape, + getattr(self, f"mask_token_{parallel_draft_index - 1}"), + device=padded_input_ids.device, + dtype=padded_input_ids.dtype, ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3) - eagle_inputs["position_ids"] = torch.cat( - (position_ids, position_ids, position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), - dim=0, - ) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - else: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids), - dim=-1, - ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 4) - eagle_inputs["position_ids"] = torch.cat( - (position_ids, position_ids, position_ids, position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), - dim=0, - ) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None + ) eagle_inputs["embedding"] = self.embedding( input_ids=eagle_inputs["input_ids"], position_ids=eagle_inputs["position_ids"], ) + eagle_inputs["hidden_states"] = ( + hidden_states if parallel_draft_index == 0 else eagle_inputs["embedding"] + ) + + eagle_inputs["attention_mask"] = set_multi_step_attention_mask( + attn_mask, ttt_step + parallel_draft_index + ) + + eagle_inputs["rotary_pos_emb"] = torch.cat( + [rotary_pos_emb] * (ttt_step + parallel_draft_index + 1), + dim=0, + ) + return eagle_inputs def _compute_eagle_loss(self, logits, labels, eagle_logits): @@ -1159,6 +924,7 @@ def _eagle_forward( output_weight, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + inference_context: StaticInferenceContext | None = None, extra_block_kwargs: dict | None = None, ): eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = self.eagle_module( @@ -1168,15 +934,24 @@ def _eagle_forward( eagle_inputs["rotary_pos_emb"], inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=inference_context, 
**(extra_block_kwargs or {}), ) + # Update inference_context.sequence_len_offset after each call to eagle_module + if inference_context is not None: + inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] + if hasattr(self.eagle_module, "eagle_output_layer"): eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states) else: eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight) - return eagle_hidden_states, eagle_logits, eagle_hidden_states_pre_final_layernorm + return ( + eagle_hidden_states, + eagle_logits, + eagle_hidden_states_pre_final_layernorm, + ) def forward( self, @@ -1189,6 +964,7 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict | None = None, return_eagle_inputs: bool = False, + ttt_steps=4, **kwargs, ) -> torch.Tensor: if position_ids is None or attention_mask is None: @@ -1230,6 +1006,12 @@ def forward( output_weight = self.shared_embedding_or_output_weight() logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight) + # EAGLE KV cache + eagle_inference_context = StaticInferenceContext( + input_ids.shape[0], + input_ids.shape[1] * (self.eagle_config.parallel_draft_step + ttt_steps - 1), + ) + if self.eagle_offline: eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state ) @@ -1257,199 +1039,119 @@ def forward( hidden_states, apply_fc=True ) - # In calibration mode, we want to make sure all weights have been exercised. - # This makes sure all quantized weights have amax calibrated - if inference_params is None or self.calibration_mode: - eagle_inputs_0 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward( - eagle_inputs_0, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - # If labels are not provided, return the original logits. We only return after - # all eagle weights have been exercised for quantization calibration purpose. - if labels is None: - return logits_sbh.transpose(0, 1).contiguous() - elif labels.shape[1] == input_ids.shape[1] - 1: - # For offline training, labels may be 1 token shorter than input_ids. - # We will just pad a 0 to the labels to make the seq_len the same as - # input_ids. This will introduce a small error in training if logit_distillation - # is False, and testing accuracy is wrong for the last token. - right_token_pad = torch.zeros( - (labels.shape[0], 1), - dtype=labels.dtype, - device=labels.device, - ) - labels = torch.cat((labels, right_token_pad), dim=-1) + if labels is not None: + if labels.shape[1] == input_ids.shape[1] - 1: + # For offline training, labels may be 1 token shorter than input_ids. + # We will just pad a 0 to the labels to make the seq_len the same as + # input_ids. This will introduce a small error in training if logit_distillation + # is False, and the test accuracy of the last token will be wrong. + right_token_pad = torch.zeros( + (labels.shape[0], 1), + dtype=labels.dtype, + device=labels.device, + ) + labels = torch.cat((labels, right_token_pad), dim=-1) - # If eagle_freeze_base_model is set to True, - # the base model is frozen . 
- loss = self.compute_language_model_loss(labels, logits_sbh) - loss = 0.0 * loss + # If eagle_freeze_base_model is set to True, + # the base model is frozen. + loss = self.compute_language_model_loss(labels, logits_sbh) + if self.eagle_freeze_base_model: + loss = 0.0 * loss - if self.eagle_config.parallel_draft_step > 1: + acc = [] + for ttt_step in range(ttt_steps): + eagle_logits = [] for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i:] - loss[:, i + 1 :] += 1.0 * loss_ - return loss - - loss_0 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_0) - loss[:, 1:] += self.eagle_loss_decay_factor * loss_0 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_0[:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}", - flush=True, + eagle_inputs = self._get_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + ttt_step=ttt_step, + parallel_draft_index=i, ) - # Second round of EAGLE loss - eagle_inputs_1 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_0_pre_norm, - ) - - _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward( - eagle_inputs_1, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :] - - loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1) - # [b, s - 2] - loss_1 = loss_1[:, 1:] - loss[:, 2:] += self.eagle_loss_decay_factor**2 * loss_1 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_1[1:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 2:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}", - flush=True, + _, eagle_logits_, eagle_hidden_states_pre_norm_ = self._eagle_forward( + eagle_inputs, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), ) - # Third EAGLE loss - eagle_inputs_2 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_2x_pre_norm, - ) - - _, eagle_logits_3x, 
eagle_hidden_states_3x_pre_norm = self._eagle_forward( - eagle_inputs_2, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) + if i == 0: + next_eagle_hidden_states_pre_norm = eagle_hidden_states_pre_norm_ - eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :] - - loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2) - # [b, s - 3] - loss_2 = loss_2[:, 2:] - loss[:, 3:] += self.eagle_loss_decay_factor**3 * loss_2 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_2[2:-1, :, :] + eagle_logits.append(eagle_logits_) + eagle_logits = torch.cat(eagle_logits, dim=0) + eagle_module_input_hidden_states = next_eagle_hidden_states_pre_norm + if self.config.sequence_parallel: + eagle_module_input_hidden_states = gather_from_sequence_parallel_region( + eagle_module_input_hidden_states ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 3:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}", - flush=True, + eagle_module_input_hidden_states = torch.cat( + ( + torch.zeros( + ( + 1, + eagle_module_input_hidden_states.shape[1], + eagle_module_input_hidden_states.shape[2], + ), + dtype=eagle_module_input_hidden_states.dtype, + device=eagle_module_input_hidden_states.device, + ), + eagle_module_input_hidden_states[:-1, :, :], + ) + ) + if self.config.sequence_parallel: + eagle_module_input_hidden_states = scatter_to_sequence_parallel_region( + eagle_module_input_hidden_states ) - # Forth EAGLE loss - eagle_inputs_3 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_3x_pre_norm, - ) - - _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward( - eagle_inputs_3, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, :, :] + # Discard the KV cache for the last parallel_draft_step - 1 tokens, + # since the next ttt_step builds only on the first draft token of the + # current ttt_step + eagle_inference_context.sequence_len_offset -= input_ids.shape[1] * ( + self.eagle_config.parallel_draft_step - 1 + ) - loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3) - # [b, s - 4] - loss_3 = loss_3[:, 3:] - loss[:, 4:] += self.eagle_loss_decay_factor**4 * loss_3 + # If labels are not provided, return the original logits. We only return after + # all eagle weights have been exercised for quantization calibration purposes. 
+ if labels is None: + return logits_sbh.transpose(0, 1).contiguous() - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_3[3:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 4:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}", - flush=True, + for i in range(self.eagle_config.parallel_draft_step): + eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logit) + loss_ = loss_[:, i + ttt_step :] + loss[:, i + ttt_step + 1 :] += ( + self.eagle_loss_decay_factor ** (ttt_step + i) * loss_ ) + if self.eagle_report_acc and not self.training: + with torch.no_grad(): + for i in range(self.eagle_config.parallel_draft_step): + gathered_logits = gather_from_tensor_model_parallel_region( + eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]] + ) + gathered_logits = gathered_logits[i + ttt_step : -1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = ( + torch.eq(labels[:, i + ttt_step + 1 :], eagle_top1).sum() + / eagle_top1.numel() + ) + acc.append(top1_p) + + if self.eagle_report_acc and not self.training and get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}" + f" EAGLE Top-1: {acc}", + flush=True, + ) + return loss def tree_decode(self, input_ids: torch.Tensor, tree: Tree): @@ -1630,6 +1332,8 @@ def pseudo_speculative_generate( ): """Pseudo generate of the EAGLE GPTModel. + This function does not support KV cache, as sequence parallelism may be enabled. 
+ Returns: base_token (torch.Tensor): token from base model draft_tokens (torch.Tensor): draft tokens from eagle module @@ -1678,17 +1382,16 @@ def pseudo_speculative_generate( draft_tokens = [] for _ in range(steps): - if self.eagle_config.parallel_draft_step > 1: - for i in range(self.eagle_config.parallel_draft_step - 1): - eagle_ids = torch.cat( - (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 - ) - hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0) + for i in range(self.eagle_config.parallel_draft_step - 1): + eagle_ids = torch.cat( + (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 + ) + # Pad dummy hidden_states for the mask tokens. + # They will be replaced by the mask-token embeddings after padding. + hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0) padded_eagle_ids, seq_len, padded_hidden_states = right_padding( eagle_ids, hidden_states ) - if self.config.sequence_parallel: - padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states) eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids( padded_eagle_ids ) @@ -1699,6 +1402,19 @@ def pseudo_speculative_generate( input_ids=padded_eagle_ids, position_ids=eagle_position_ids, ) + if self.config.sequence_parallel: + gathered_embedding = gather_from_sequence_parallel_region(eagle_inputs["embedding"]) + else: + gathered_embedding = eagle_inputs["embedding"] + if self.eagle_config.parallel_draft_step > 1: + # Replace dummy hidden_states with embedding for mask tokens + padded_hidden_states[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] = gathered_embedding[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] + if self.config.sequence_parallel: + padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states) eagle_inputs["hidden_states"] = padded_hidden_states eagle_inputs["attention_mask"] = eagle_attention_mask @@ -1717,31 +1433,31 @@ def pseudo_speculative_generate( ) eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :] - if self.eagle_config.parallel_draft_step > 1: - draft_token = ( - gather_from_tensor_model_parallel_region(eagle_logits)[ - -self.eagle_config.parallel_draft_step :, :, : - ] - .argmax(dim=-1) - .transpose(0, 1) - ) - else: - draft_token = ( - gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :] - .argmax(dim=-1) - .transpose(0, 1) - ) + draft_token = ( + gather_from_tensor_model_parallel_region(eagle_logits)[ + -self.eagle_config.parallel_draft_step :, :, : + ] + .argmax(dim=-1) + .transpose(0, 1) + ) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] - if self.eagle_config.parallel_draft_step > 1: - return base_token, draft_token - draft_tokens.append(draft_token) + # Remove mask tokens from eagle_ids before appending draft_token, + # and drop the dummy hidden_states that were padded for them + if self.eagle_config.parallel_draft_step > 1: + eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1] + hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1] + eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1) hidden_states = torch.cat( - (hidden_states, eagle_next_hidden_states_input[-1:, :, :]), dim=0 + ( + hidden_states, + eagle_next_hidden_states_input[-self.eagle_config.parallel_draft_step :, :, :], + ), + dim=0, ) draft_tokens = torch.cat(draft_tokens, dim=-1) diff --git a/modelopt/torch/speculative/utils.py 
b/modelopt/torch/speculative/utils.py index 648cc8163..1d9a2d5a6 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -290,11 +290,13 @@ def check_draft(self, ground_truth, input_ids, draft_tokens): return input_ids - def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=True): + def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=False): """This function checks the data consistency across all ranks in the group. Use rank 0 data as the golden set to broadcast to all ranks. - Each rank will then compare to this data and through error if different. + Each rank compares its data against this golden set and either raises + (when fail_when_mismatch=True) or emits a warning while forcing every + rank to adopt rank 0's data. """ if not torch.distributed.is_initialized(): return data @@ -346,16 +348,12 @@ def validate( if tree_paths: input_id, draft_tokens, pred_tokens = self.model.tree_decode(input_ids, tree=tree) - pred_tokens = self.check_data_consistency_across_ranks( - pred_tokens, fail_when_mismatch=False - ) + pred_tokens = self.check_data_consistency_across_ranks(pred_tokens) else: input_id, draft_tokens = self.model.pseudo_speculative_generate( input_ids, steps=steps ) - draft_tokens = self.check_data_consistency_across_ranks( - draft_tokens, fail_when_mismatch=False - ) + draft_tokens = self.check_data_consistency_across_ranks(draft_tokens) input_id = self.check_data_consistency_across_ranks(input_id) input_ids = torch.cat((input_ids, input_id), dim=-1)
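
Reviewer notes (sketches, not part of the patch):

The rewritten set_multi_step_attention_mask no longer special-cases steps 2-4: it grows the mask iteratively, keeping s query rows and appending one s x s key block per step, which matches the new KV-cache layout (earlier blocks are served from the cache). Below is a minimal, self-contained rendition of that loop, assuming Megatron's convention that True marks masked-out positions and that `s = attn_mask.shape[-1]` is still computed before the loop as in the pre-change code:

    import torch

    def build_multi_step_mask(attn_mask: torch.Tensor, step: int) -> torch.Tensor:
        """Grow a [b, 1, s, s] mask into [b, 1, s, (step + 1) * s], one block per step."""
        s = attn_mask.shape[-1]
        for step_idx in range(step):
            # Shift the accumulated mask left by one column: draft k predicts token t + k + 1.
            mask_0 = attn_mask.clone().detach()
            mask_0[:, :, step_idx, :] = True  # row step_idx has no valid label at this depth
            mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:]
            # Fresh diagonal block: each newly appended position may attend to its own slot.
            mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool()
            for i in range(step_idx + 1, s - 1):
                mask_1[:, :, i, i] = False
            attn_mask = torch.cat((mask_0, mask_1), dim=-1)
        return attn_mask

    base = torch.triu(torch.ones(1, 1, 8, 8), diagonal=1).bool()  # causal; True = masked
    print(build_multi_step_mask(base, step=0).shape)  # torch.Size([1, 1, 8, 8])
    print(build_multi_step_mask(base, step=2).shape)  # torch.Size([1, 1, 8, 24])

The StaticInferenceContext bookkeeping in forward() can be checked the same way. Each _eagle_forward call advances sequence_len_offset by the input length; at the end of every ttt_step the offsets for the parallel mask-token drafts are rewound, so only the first draft's KV entries carry over. A plain-integer simulation (illustrative names) of why the context is sized to input length * (parallel_draft_step + ttt_steps - 1):

    def simulate_kv_offset(s: int, parallel_draft_step: int, ttt_steps: int) -> int:
        offset = peak = 0
        for _ in range(ttt_steps):
            for _ in range(parallel_draft_step):
                offset += s  # one eagle_module call per parallel draft
                peak = max(peak, offset)
            offset -= s * (parallel_draft_step - 1)  # discard mask-token KV entries
        return peak

    # The peak occurs inside the last ttt_step, before the final discard.
    assert simulate_kv_offset(s=8, parallel_draft_step=3, ttt_steps=4) == 8 * (3 + 4 - 1)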