diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index 69051f6e8..9f67efb92 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -47,5 +47,4 @@ def modify( self.eagle_loss_decay_factor = eagle_loss_decay_factor if eagle_architecture_config.get("parallel_draft_step", 1) > 1: - for i in range(eagle_architecture_config.get("parallel_draft_step") - 1): - self.register_buffer(f"mask_token_{i}", torch.tensor(-1)) + self.register_buffer("mask_token", torch.tensor(-1)) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2a0e63a3c..bc9479752 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -26,6 +26,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel @@ -139,7 +140,7 @@ def dict_to_config( def mcore_version_higher_than(target_version: str): - """Check if megatron-core is least this version.""" + """Check if megatron-core is greater than this version.""" return Version(megatron.core.__version__) > Version(target_version) @@ -195,22 +196,73 @@ def set_multi_step_attention_mask(attn_mask, step): l0 l1 l2 l3 l4 l5 l6 l7 (base labels) - (1st) | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | + ttt_step=0 + | i1 i2 i3 i4 i5 i6 i7 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | ========================================= - f1 l1 | i1 h0 | x | - f2 
l2 | i2 h1 | x x | - f3 l3 | i3 h2 | x x x | - f4 l4 | i4 h3 | x x x x | - f5 l5 | i5 h4 | x x x x x | - f6 l6 | i6 h5 | x x x x x x | - f7 l7 | i7 h6 | x x x x x x x | + F1 l1 | i1 h0 | x | + F2 l2 | i2 h1 | x x | + F3 l3 | i3 h2 | x x x | + F4 l4 | i4 h3 | x x x x | + F5 l5 | i5 h4 | x x x x x | + F6 l6 | i6 h5 | x x x x x x | + F7 l7 | i7 h6 | x x x x x x x | -- -- | -- h7 | o o o o o o o o | ========================================= - (2nd) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | + ttt_step=1 + | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | + =================================================================== + -- -- | i1 -- | | | + J2 l2 | i2 F1 | x o | x | + J3 l3 | i3 F2 | x x o | x | + J4 l4 | i4 F3 | x x x o | x | + J5 l5 | i5 F4 | x x x x o | x | + J6 l6 | i6 F5 | x x x x x o | x | + J7 l7 | i7 F6 | x x x x x x o | x | + -- -- | -- F7 | | | + =================================================================== + -- -- | m0 M0 | | | + -- -- | m0 M0 | | | + K3 l3 | m0 M0 | x o o | x o | + K4 l4 | m0 M0 | x x o o | x o | + K5 l5 | m0 M0 | x x x o o | x o | + K6 l6 | m0 M0 | x x x x o o | x o | + K7 l7 | m0 M0 | x x x x x o o | x o | + -- -- | -- M0 | | | + =================================================================== + """ + s = attn_mask.shape[-1] + for step_idx in range(step): + mask_0 = attn_mask.clone().detach() + mask_0[:, :, step_idx, :] = True + mask_0[:, :, :, :-1] = mask_0[:, :, :, 1:] + mask_1 = attn_mask.new_ones(attn_mask.shape[0], attn_mask.shape[1], s, s).bool() + for i in range(step_idx + 1, s - 1): + mask_1[:, :, i, i] = False + + attn_mask = torch.cat((mask_0, mask_1), dim=-1) + + return attn_mask + + +def set_diffusion_attention_mask(attn_mask, step, block_len=1): + """Given an original attention_mask, construct a multi-step attention_mask. 
+ + i0 i1 i2 i3 i4 i5 i6 i7 (base input_ids) + ======================= + h0 h1 h2 h3 h4 h5 h6 h7 (base hidden_states) + l0 l1 l2 l3 l4 l5 l6 l7 (base labels) + + + + + ttt_step=0 + block_len=2 + | i1 i2 i3 i4 i5 i6 i7 -- | i1 mm i3 mm mm mm i7 -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | i1 mm i3 mm mm mm i7 -- | =================================================================== F1 l1 | i1 h0 | x | | F2 l2 | i2 h1 | x x | | @@ -221,155 +273,98 @@ def set_multi_step_attention_mask(attn_mask, step): F7 l7 | i7 h6 | x x x x x x x | | -- -- | -- h7 | o o o o o o o o | | =================================================================== - -- -- | i1 -- | | | - G2 l2 | i2 F1 | x o | x | - G3 l3 | i3 F2 | x x o | x | - G4 l4 | i4 F3 | x x x o | x | - G5 l5 | i5 F4 | x x x x o | x | - G6 l6 | i6 F5 | x x x x x o | x | - G7 l7 | i7 F6 | x x x x x x o | x | - -- -- | -- F7 | | | + -- -- | i1 i1 | o | o o | + -- l1 | mm mm | o | o o | + -- -- | i3 i3 | x x o | x x | + -- l3 | mm mm | x x o | x x | + -- l4 | mm mm | x x x x o | x x | + -- l5 | mm mm | x x x x o | x x | + -- -- | i7 i7 | x x x x x x o | x x | + -- -- | -- -- | | | =================================================================== - (3rd) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- G2 G3 G4 G5 G6 G7 | - ============================================================================================= - F1 l1 | i1 h0 | x | | | - F2 l2 | i2 h1 | x x | | | - F3 l3 | i3 h2 | x x x | | | - F4 l4 | i4 h3 | x x x x | | | - F5 l5 | i5 h4 | x x x x x | | | - F6 l6 | i6 h5 | x x x x x x | | | - F7 l7 | i7 h6 | x x x x x x x | | | - -- -- | -- h7 | o o o o o o o o | | | + ttt_step=1 + | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | mm i2 mm mm i5 mm mm -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | mm i2 mm mm i5 mm mm -- | 
============================================================================================= -- -- | i1 -- | | | | - G2 l2 | i2 F1 | x o | x | | - G3 l3 | i3 F2 | x x o | x | | - G4 l4 | i4 F3 | x x x o | x | | - G5 l5 | i5 F4 | x x x x o | x | | - G6 l6 | i6 F5 | x x x x x o | x | | - G7 l7 | i7 F6 | x x x x x x o | x | | + J2 l2 | i2 F1 | x o | x | | + J3 l3 | i3 F2 | x x o | x | | + J4 l4 | i4 F3 | x x x o | x | | + J5 l5 | i5 F4 | x x x x o | x | | + J6 l6 | i6 F5 | x x x x x o | x | | + J7 l7 | i7 F6 | x x x x x x o | x | | -- -- | -- F7 | | | | ============================================================================================= - -- -- | i1 -- | | | | - -- -- | i2 -- | | | | - H3 l3 | i3 G2 | x o o | x o | x | - H4 l4 | i4 G3 | x x o o | x o | x | - H5 l5 | i5 G4 | x x x o o | x o | x | - H6 l6 | i6 G5 | x x x x o o | x o | x | - H7 l7 | i7 G6 | x x x x x o o | x o | x | - -- -- | -- G7 | | | | + -- l0 | mm mm | | | | + -- -- | i2 i2 | | | | + -- l2 | mm mm | x o | x | x x | + -- l3 | mm mm | x o | x | x x | + -- -- | i5 i5 | x x x o | x | x x | + -- l5 | mm mm | x x x o | x | x x | + -- l6 | mm mm | x x x x x o | x | x x | + -- -- | -- -- | | | | ============================================================================================= - (4th) | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | - (out) | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- G2 G3 G4 G5 G6 G7 | -- -- -- H3 H4 H5 H6 H7 | - ======================================================================================================================= - F1 l1 | i1 h0 | x | | | | - F2 l2 | i2 h1 | x x | | | | - F3 l3 | i3 h2 | x x x | | | | - F4 l4 | i4 h3 | x x x x | | | | - F5 l5 | i5 h4 | x x x x x | | | | - F6 l6 | i6 h5 | x x x x x x | | | | - F7 l7 | i7 h6 | x x x x x x x | | | | - -- -- | -- h7 | o o o o o o o o | | | | + ttt_step=2 + | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- | i1 i2 i3 i4 i5 i6 i7 -- 
| i1 mm mm i4 i5 mm mm -- | + | h0 h1 h2 h3 h4 h5 h6 h7 | -- F1 F2 F3 F4 F5 F6 F7 | -- -- J2 J4 J5 J6 J7 -- | i1 mm mm i4 i5 mm mm -- | ======================================================================================================================= -- -- | i1 -- | | | | | - G2 l2 | i2 F1 | x o | x | | | - G3 l3 | i3 F2 | x x o | x | | | - G4 l4 | i4 F3 | x x x o | x | | | - G5 l5 | i5 F4 | x x x x o | x | | | - G6 l6 | i6 F5 | x x x x x o | x | | | - G7 l7 | i7 F6 | x x x x x x o | x | | | + J2 l2 | i2 F1 | | | | | + J3 l3 | i3 F2 | x o | x | x | | + J4 l4 | i4 F3 | x x o | x | x | | + J5 l5 | i5 F4 | x x x o | x | x | | + J6 l6 | i6 F5 | x x x x o | x | x | | + J7 l7 | i7 F6 | x x x x x o | x | x | | -- -- | -- F7 | | | | | ======================================================================================================================= - -- -- | i1 -- | | | | | - -- -- | i2 -- | | | | | - H3 l3 | i3 G2 | x o o | x o | x | | - H4 l4 | i4 G3 | x x o o | x o | x | | - H5 l5 | i5 G4 | x x x o o | x o | x | | - H6 l6 | i6 G5 | x x x x o o | x o | x | | - H7 l7 | i7 G6 | x x x x x o o | x o | x | | - -- -- | -- G7 | | | | | - ======================================================================================================================= - -- -- | i1 -- | | | | | - -- -- | i2 -- | | | | | - -- -- | i3 -- | | | | | - K4 l4 | i4 H3 | x | x | x | x | - K5 l5 | i5 H4 | x x | x | x | x | - K6 l6 | i6 H5 | x x x | x | x | x | - K7 l7 | i7 H6 | x x x x | x | x | x | - -- -- | -- H7 | | | | | + -- l0 | mm mm | | | | | + -- -- | i2 i2 | | | | | + -- l2 | mm mm | | | | | + -- l3 | mm mm | | | | | + -- -- | i5 i5 | x x o | x | x | x x | + -- l5 | mm mm | x x o | x | x | x x | + -- l6 | mm mm | x x x x o | x | x | x x | + -- -- | -- -- | | | | | ======================================================================================================================= """ # noqa: E501 - assert step > 1, "step should be larger than 1 in multi-step attention mask." 
- assert step <= 4, "Currently only a step of 4 or smaller is supported!" - s = attn_mask.shape[-1] - zero_mask = torch.ones_like(attn_mask).bool() - mask_2_1 = attn_mask.clone().detach() - mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:] - mask_2_2 = torch.ones_like(attn_mask).bool() - for i in range(1, s - 1): - mask_2_2[:, :, i, i] = False - - if step == 2: - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2), dim=-1), - ), - dim=-2, - ) - return attn_mask - - mask_3_1 = mask_2_1.clone().detach() - mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:] - mask_3_2 = mask_2_2.clone().detach() - mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:] - mask_3_2[:, :, 1, 0] = True - mask_3_3 = mask_2_2.clone().detach() - mask_3_3[:, :, 1, 1] = True - - if step == 3: - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1), - ), - dim=-2, - ) - return attn_mask - - mask_4_1 = mask_3_1.clone().detach() - mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:] - mask_4_2 = mask_3_2.clone().detach() - mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:] - mask_4_2[:, :, 2, 0] = True - mask_4_3 = mask_3_3.clone().detach() - mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:] - mask_4_3[:, :, 2, 1] = True - mask_4_4 = mask_3_3.clone().detach() - mask_4_4[:, :, 2, 2] = True - - attn_mask = torch.cat( - ( - torch.cat((attn_mask, zero_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1), - torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1), - ), - dim=-2, - ) - return attn_mask + idx = torch.arange(s, device=attn_mask.device) + i = idx.view(s, 1) # (seq_len, 1) + j = idx.view(1, s) # (1, seq_len) + k = block_len * (i // block_len) # (seq_len, 1) + valid = k >= step + 1 # (seq_len, 1) + layers = [] + # First 
layer: True for columns 0 to k(i)-step if k(i) >= step+1, else all False + first_layer = (j < (k - step)) & valid + layers.append(first_layer) + # Middle layers: for layer l in step ... 1, True only at k(i)-l if k(i) >= step+1, else all False + for l in range(step, 0, -1): + col = k - l + middle_layer = (j == col) & valid + layers.append(middle_layer) + # Last layer: True for columns k(i) to k(i)+block_size-1 (inclusive) if k(i) >= step+1, else all False + last_layer = (j >= k) & (j < (k + block_len)) + layers.append(last_layer) + mask = torch.cat(layers, dim=1) + # Revert mask + mask = ~mask + mask = mask.repeat(attn_mask.shape[0], attn_mask.shape[1], 1, 1) + + return mask class EagleLanguageModelEmbedding(LanguageModelEmbedding): """Allow last pp stage to also load the embedding.""" + def __init__(self, extra_embedding, *args, **kwargs): + """If extra_embedding is False, this is just a replica of base model LanguageModelEmbedding.""" + super().__init__(*args, **kwargs) + self.extra_embedding = extra_embedding + def sharded_state_dict( self, prefix: str = "", @@ -387,7 +382,11 @@ def sharded_state_dict( allow_shape_mismatch=True, prepend_offsets=sharded_offsets, # (PP, TP, DP) - replica_id=(1, 0, get_data_parallel_rank(with_context_parallel=True)), + replica_id=( + 0 if self.extra_embedding else 1, + 0, + get_data_parallel_rank(with_context_parallel=True), + ), ) } @@ -570,6 +569,16 @@ def __init__( skip_weight_param_allocation=False, ) + if self.config.parallel_draft_step > 1: + self.embedding = EagleLanguageModelEmbedding( + extra_embedding=True, + config=self.config, + vocab_size=self.config.vocab_size + + self.config.tensor_model_parallel_size, # for mask token + max_sequence_length=self.config.max_sequence_length, + position_embedding_type=self.config.position_embedding_type, + ) + def _get_eagle_transformer_layer_spec(self, config): """Get the TransformerLayer implementation spec. 
@@ -623,6 +632,7 @@ def forward( rotary_pos_emb: torch.Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + inference_context: StaticInferenceContext | None = None, extra_block_kwargs: dict | None = None, ) -> torch.Tensor: """Forward function.""" @@ -663,6 +673,7 @@ def forward( inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, packed_seq_params=packed_seq_params, + inference_context=inference_context, **(extra_block_kwargs or {}), ) @@ -784,6 +795,7 @@ def modify( if not self.pre_process and self.post_process: self.embedding = EagleLanguageModelEmbedding( + extra_embedding=False, config=self.config, vocab_size=self.vocab_size, max_sequence_length=self.max_sequence_length, @@ -870,14 +882,15 @@ def _get_eagle_module_inputs( hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, - features: torch.Tensor | None = None, + ttt_step: int = 0, + diffusion: bool = False, + max_block_size: int = 8, ): """Getting EAGLE module inputs.""" - b = hidden_states.shape[1] - h = hidden_states.shape[2] - # [b, 1] - id_padding = torch.zeros((b, 1), dtype=input_ids.dtype, device=input_ids.device) + id_padding = torch.zeros( + (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device + ) padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) @@ -889,187 +902,50 @@ def _get_eagle_module_inputs( eagle_inputs = {} - if self.eagle_config.parallel_draft_step > 1: - eagle_inputs["input_ids"] = padded_input_ids - eagle_inputs["position_ids"] = position_ids - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - else: - gathered_hidden_states = 
hidden_states - eagle_inputs["hidden_states"] = gathered_hidden_states + eagle_inputs["input_ids"] = padded_input_ids + eagle_inputs["position_ids"] = position_ids - for i in range(self.eagle_config.parallel_draft_step - 1): - eagle_inputs["input_ids"] = torch.cat( - ( - eagle_inputs["input_ids"], - torch.full( - padded_input_ids.shape, - getattr(self, f"mask_token_{i}"), - device=padded_input_ids.device, - dtype=padded_input_ids.dtype, - ), - ), - dim=-1, - ) - - eagle_inputs["hidden_states"] = torch.cat( - ( - eagle_inputs["hidden_states"], - torch.zeros( - (1 + i, b, h), dtype=hidden_states.dtype, device=hidden_states.device - ), - gathered_hidden_states[: -(1 + i)], - ), - dim=0, - ) - - eagle_inputs["position_ids"] = torch.cat( - (eagle_inputs["position_ids"], position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (eagle_inputs["rotary_pos_emb"], rotary_pos_emb), dim=0 - ) - - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask( - attn_mask, self.eagle_config.parallel_draft_step + if self.eagle_config.parallel_draft_step > 1: + eagle_inputs["embedding"] = self.eagle_module.embedding( + input_ids=eagle_inputs["input_ids"], + position_ids=eagle_inputs["position_ids"], ) - elif features is None: - eagle_inputs["input_ids"] = padded_input_ids - eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = attn_mask - eagle_inputs["position_ids"] = position_ids - eagle_inputs["rotary_pos_emb"] = rotary_pos_emb - elif features.shape[0] == hidden_states.shape[0]: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids), - dim=-1, + else: + eagle_inputs["embedding"] = self.embedding( + input_ids=eagle_inputs["input_ids"], + position_ids=eagle_inputs["position_ids"], ) + eagle_inputs["hidden_states"] = hidden_states - if 
self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) + eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, ttt_step) - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 2) - eagle_inputs["position_ids"] = torch.cat((position_ids, position_ids), dim=-1) + eagle_inputs["rotary_pos_emb"] = rotary_pos_emb.repeat(ttt_step + 1, 1, 1, 1) - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=0) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - elif features.shape[0] == hidden_states.shape[0] * 2: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids, padded_input_ids), - dim=-1, + # Update eagle_inputs for diffusion + if diffusion: + block_len = torch.randint(1, max_block_size + 1, (1,)).item() + threshold = torch.rand(size=(input_ids.shape[0], 1), device=input_ids.device) + input_ids_mask = torch.rand(size=input_ids.shape, device=input_ids.device) > threshold + eagle_inputs["block_len"] = block_len + eagle_inputs["diffusion_loss_mask"] = input_ids_mask + eagle_inputs["input_ids"] = torch.where( + condition=input_ids_mask, input=self.mask_token, other=eagle_inputs["input_ids"] ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = 
gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, + eagle_inputs["embedding"] = self.eagle_module.embedding( + input_ids=eagle_inputs["input_ids"], + position_ids=eagle_inputs["position_ids"], ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, 3) - eagle_inputs["position_ids"] = torch.cat( - (position_ids, position_ids, position_ids), dim=-1 + eagle_inputs["hidden_states"] = eagle_inputs["embedding"] + eagle_inputs["attention_mask"] = set_diffusion_attention_mask( + attn_mask, + ttt_step, + block_len, ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), - dim=0, - ) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - else: - eagle_inputs["input_ids"] = torch.cat( - (padded_input_ids, padded_input_ids, padded_input_ids, padded_input_ids), - dim=-1, - ) - - if self.config.sequence_parallel: - gathered_hidden_states = gather_from_sequence_parallel_region(hidden_states) - gathered_features = gather_from_sequence_parallel_region(features) - else: - gathered_hidden_states = hidden_states - gathered_features = features - eagle_inputs["hidden_states"] = torch.cat( - ( - gathered_hidden_states, - torch.zeros((1, b, h), dtype=hidden_states.dtype, device=hidden_states.device), - gathered_features[:-1, :, :], - ), - dim=0, - ) - if self.config.sequence_parallel: - eagle_inputs["hidden_states"] = scatter_to_sequence_parallel_region( - eagle_inputs["hidden_states"] - ) - - eagle_inputs["attention_mask"] 
= set_multi_step_attention_mask(attn_mask, 4) - eagle_inputs["position_ids"] = torch.cat( - (position_ids, position_ids, position_ids, position_ids), dim=-1 - ) - - if rotary_pos_emb is not None: - eagle_inputs["rotary_pos_emb"] = torch.cat( - (rotary_pos_emb, rotary_pos_emb, rotary_pos_emb, rotary_pos_emb), - dim=0, - ) - else: - # [TODO] (yeyu): there will be problem here with MLA - eagle_inputs["rotary_pos_emb"] = None - - eagle_inputs["embedding"] = self.embedding( - input_ids=eagle_inputs["input_ids"], - position_ids=eagle_inputs["position_ids"], - ) + eagle_inputs["rotary_pos_emb"] = rotary_pos_emb.repeat(ttt_step + 2, 1, 1, 1) return eagle_inputs - def _compute_eagle_loss(self, logits, labels, eagle_logits): + def _compute_eagle_loss(self, logits, labels, eagle_logits, shift_labels: bool = True): """Compute the total loss for EAGLE. logits: [s, b, vocab // TP] @@ -1079,9 +955,14 @@ def _compute_eagle_loss(self, logits, labels, eagle_logits): # Compute lm loss (classification loss) or KLDivergence if self.eagle_self_logit_distillation: mapping = self.eagle_module.d2t if hasattr(self.eagle_module, "d2t") else None - token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping) - else: + if shift_labels: + token_loss = self.kld(eagle_logits[:-1, :, :], logits[1:, :, :], mapping) + else: + token_loss = self.kld(eagle_logits[:-1, :, :], logits[:-1, :, :], mapping) + elif shift_labels: token_loss = self.compute_language_model_loss(labels[:, 1:], eagle_logits[:-1, :, :]) + else: + token_loss = self.compute_language_model_loss(labels[:, :-1], eagle_logits[:-1, :, :]) # [b, s - 1] return token_loss @@ -1159,7 +1040,9 @@ def _eagle_forward( output_weight, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, + inference_context: StaticInferenceContext | None = None, extra_block_kwargs: dict | None = None, + update_sequence_len_offset: bool = True, ): eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = 
self.eagle_module( eagle_inputs["embedding"], @@ -1168,15 +1051,24 @@ def _eagle_forward( eagle_inputs["rotary_pos_emb"], inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=inference_context, **(extra_block_kwargs or {}), ) + # Update inference_context.sequence_len_offset after each call of eagle_module + if inference_context is not None and update_sequence_len_offset: + inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1] + if hasattr(self.eagle_module, "eagle_output_layer"): eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states) else: eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight) - return eagle_hidden_states, eagle_logits, eagle_hidden_states_pre_final_layernorm + return ( + eagle_hidden_states, + eagle_logits, + eagle_hidden_states_pre_final_layernorm, + ) def forward( self, @@ -1189,6 +1081,7 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict | None = None, return_eagle_inputs: bool = False, + ttt_steps=4, **kwargs, ) -> torch.Tensor: if position_ids is None or attention_mask is None: @@ -1230,6 +1123,12 @@ def forward( output_weight = self.shared_embedding_or_output_weight() logits_sbh, _ = self.output_layer(hidden_states, weight=output_weight) + # EAGLE kv cache + eagle_inference_context = StaticInferenceContext( + input_ids.shape[0], + input_ids.shape[1] * (ttt_steps + 1), + ) + if self.eagle_offline: eagle_module_input_hidden_states = self._get_eagle_input_hidden_states( aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state @@ -1257,198 +1156,150 @@ def forward( hidden_states, apply_fc=True ) - # In calibration mode, we want to make sure all weights have been exercised. 
- # This makes sure all quantized weights have amax calibrated - if inference_params is None or self.calibration_mode: - eagle_inputs_0 = self._get_eagle_module_inputs( + if labels is not None: + if labels.shape[1] == input_ids.shape[1] - 1: + # For offline training, labels may be 1 token shorter than input_ids. + # We will just pad a 0 to the labels to make the seq_len the same as + # input_ids. This will introduce a small error in training if logit_distillation + # is False, and testing accuracy is wrong for the last token. + right_token_pad = torch.zeros( + (labels.shape[0], 1), + dtype=labels.dtype, + device=labels.device, + ) + labels = torch.cat((labels, right_token_pad), dim=-1) + + # If eagle_freeze_base_model is set to True, + # the base model is frozen . + loss = self.compute_language_model_loss(labels, logits_sbh) + if self.eagle_freeze_base_model: + loss = 0.0 * loss + + acc = [] + for ttt_step in range(ttt_steps): + eagle_logits = [] + + eagle_inputs = self._get_eagle_module_inputs( input_ids=input_ids, hidden_states=eagle_module_input_hidden_states, attention_mask=attention_mask, position_ids=position_ids, + ttt_step=ttt_step, ) - _, eagle_logits_0, eagle_hidden_states_0_pre_norm = self._eagle_forward( - eagle_inputs_0, + _, eagle_logits_, eagle_module_input_hidden_states = self._eagle_forward( + eagle_inputs, output_weight, inference_params=inference_params, packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, **(extra_block_kwargs or {}), ) - # If labels are not provided, return the original logits. We only return after - # all eagle weights have been exercised for quantization calibration purpose. - if labels is None: - return logits_sbh.transpose(0, 1).contiguous() - elif labels.shape[1] == input_ids.shape[1] - 1: - # For offline training, labels may be 1 token shorter than input_ids. - # We will just pad a 0 to the labels to make the seq_len the same as - # input_ids. 
This will introduce a small error in training if logit_distillation - # is False, and testing accuracy is wrong for the last token. - right_token_pad = torch.zeros( - (labels.shape[0], 1), - dtype=labels.dtype, - device=labels.device, - ) - labels = torch.cat((labels, right_token_pad), dim=-1) - - # If eagle_freeze_base_model is set to True, - # the base model is frozen . - loss = self.compute_language_model_loss(labels, logits_sbh) - loss = 0.0 * loss + eagle_logits.append(eagle_logits_) - if self.eagle_config.parallel_draft_step > 1: - for i in range(self.eagle_config.parallel_draft_step): - eagle_logits = eagle_logits_0[i * labels.shape[1] : (i + 1) * labels.shape[1]] - loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logits) - loss_ = loss_[:, i:] - loss[:, i + 1 :] += 1.0 * loss_ - return loss - - loss_0 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_0) - loss[:, 1:] += self.eagle_loss_decay_factor * loss_0 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_0[:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 1:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 1st Top-1: {acc}", - flush=True, + if self.eagle_config.parallel_draft_step > 1: + # Diffusion training within EAGLE module + eagle_inputs = self._get_eagle_module_inputs( + input_ids=input_ids, + hidden_states=eagle_module_input_hidden_states, # Not used in diffusion + attention_mask=attention_mask, + position_ids=position_ids, + ttt_step=ttt_step, + diffusion=True, + max_block_size=self.eagle_config.parallel_draft_step - 1, ) - # Second round of 
EAGLE loss - eagle_inputs_1 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_0_pre_norm, - ) - - _, eagle_logits_2x, eagle_hidden_states_2x_pre_norm = self._eagle_forward( - eagle_inputs_1, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - eagle_logits_1 = eagle_logits_2x[-labels.shape[1] :, :, :] - - loss_1 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_1) - # [b, s - 2] - loss_1 = loss_1[:, 1:] - loss[:, 2:] += self.eagle_loss_decay_factor**2 * loss_1 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_1[1:-1, :, :] - ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 2:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 2nd Top-1: {acc}", - flush=True, + _, eagle_logits_, _ = self._eagle_forward( + eagle_inputs, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + update_sequence_len_offset=False, + **(extra_block_kwargs or {}), ) - # Third EAGLE loss - eagle_inputs_2 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_2x_pre_norm, - ) + eagle_logits.append(eagle_logits_) - _, eagle_logits_3x, eagle_hidden_states_3x_pre_norm = self._eagle_forward( - eagle_inputs_2, - 
output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - eagle_logits_2 = eagle_logits_3x[-labels.shape[1] :, :, :] - - loss_2 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_2) - # [b, s - 3] - loss_2 = loss_2[:, 2:] - loss[:, 3:] += self.eagle_loss_decay_factor**3 * loss_2 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_2[2:-1, :, :] + if self.config.sequence_parallel: + eagle_module_input_hidden_states = gather_from_sequence_parallel_region( + eagle_module_input_hidden_states ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 3:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 3rd Top-1: {acc}", - flush=True, + eagle_module_input_hidden_states = torch.cat( + ( + torch.zeros( + ( + 1, + eagle_module_input_hidden_states.shape[1], + eagle_module_input_hidden_states.shape[2], + ), + dtype=eagle_module_input_hidden_states.dtype, + device=eagle_module_input_hidden_states.device, + ), + eagle_module_input_hidden_states[:-1, :, :], ) - - # Forth EAGLE loss - eagle_inputs_3 = self._get_eagle_module_inputs( - input_ids=input_ids, - hidden_states=eagle_module_input_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - features=eagle_hidden_states_3x_pre_norm, - ) - - _, eagle_logits_4x, eagle_hidden_states_4x_pre_norm = self._eagle_forward( - eagle_inputs_3, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - eagle_logits_3 = eagle_logits_4x[-labels.shape[1] :, 
:, :] - - loss_3 = self._compute_eagle_loss(logits_sbh, labels, eagle_logits_3) - # [b, s - 4] - loss_3 = loss_3[:, 3:] - loss[:, 4:] += self.eagle_loss_decay_factor**4 * loss_3 - - if self.eagle_report_acc and not self.training: - acc = [] - with torch.no_grad(): - gathered_logits = gather_from_tensor_model_parallel_region( - eagle_logits_3[3:-1, :, :] + ) + if self.config.sequence_parallel: + eagle_module_input_hidden_states = scatter_to_sequence_parallel_region( + eagle_module_input_hidden_states ) - eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - eagle_top1 += self.eagle_module.d2t[eagle_top1] - top1_p = torch.eq(labels[:, 4:], eagle_top1).sum() / eagle_top1.numel() - acc.append(top1_p) - - if get_tensor_model_parallel_rank() == 0: - print( - f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3} EAGLE 4th Top-1: {acc}", - flush=True, + + # If labels are not provided, return the original logits. We only return after + # all eagle weights have been exercised for quantization calibration purpose. + if labels is None: + return logits_sbh.transpose(0, 1).contiguous() + + # TTT loss computation + eagle_logit = eagle_logits[0] + loss_ = self._compute_eagle_loss(logits_sbh, labels, eagle_logit) + loss_ = loss_[:, ttt_step:] + loss[:, ttt_step + 1 :] += self.eagle_loss_decay_factor**ttt_step * loss_ + # Diffusion loss computation + if self.eagle_config.parallel_draft_step > 1: + eagle_logit = eagle_logits[1] + # Diffusion labels = input_ids, so we do not shift labels here. 
+ # Diffusion loss is only computed for the masked tokens + loss_ = self._compute_eagle_loss( + logits_sbh, labels, eagle_logit, shift_labels=False ) + loss_ = loss_ * eagle_inputs["diffusion_loss_mask"][:, :-1] + shift_idx = (ttt_step // eagle_inputs["block_len"] + 1) * eagle_inputs["block_len"] + loss_ = loss_[:, shift_idx:] + loss[:, shift_idx + 1 :] += self.eagle_loss_decay_factor**ttt_step * loss_ + + if self.eagle_report_acc and not self.training: + with torch.no_grad(): + eagle_logit = eagle_logits[0] + gathered_logits = gather_from_tensor_model_parallel_region(eagle_logit) + gathered_logits = gathered_logits[ttt_step:-1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = ( + torch.eq(labels[:, ttt_step + 1 :], eagle_top1).sum() / eagle_top1.numel() + ) + acc.append(top1_p) + if self.eagle_config.parallel_draft_step > 1: + # Diffusion accuracy only for masked tokens + eagle_logit = eagle_logits[1] + gathered_logits = gather_from_tensor_model_parallel_region(eagle_logit) + gathered_logits = gathered_logits[shift_idx:-1] + eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: + eagle_top1 += self.eagle_module.d2t[eagle_top1] + top1_p = ( + torch.eq(labels[:, shift_idx:-1], eagle_top1) + * eagle_inputs["diffusion_loss_mask"][:, shift_idx:-1] + ).sum() / eagle_inputs["diffusion_loss_mask"][:, shift_idx:-1].sum() + acc.append(top1_p) + + if self.eagle_report_acc and not self.training and get_tensor_model_parallel_rank() == 0: + print( + f"{torch.distributed.get_rank():3}/{torch.distributed.get_world_size():3}" + f"EAGLE Top-1: {acc}", + flush=True, + ) return loss @@ -1627,9 +1478,12 @@ def pseudo_speculative_generate( self, input_ids: torch.Tensor, steps: int = 1, + threshold: float = 0.5, ): """Pseudo generate of the EAGLE GPTModel.
+ This function does not support kv cache as sequence parallel may be enabled. + Returns: base_token (torch.Tensor): token from base model draft_tokens (torch.Tensor): draft tokens from eagle module @@ -1679,37 +1533,127 @@ def pseudo_speculative_generate( draft_tokens = [] for _ in range(steps): if self.eagle_config.parallel_draft_step > 1: - for i in range(self.eagle_config.parallel_draft_step - 1): - eagle_ids = torch.cat( - (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1 - ) - hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0) + # Pad mask_token and dummy hidden_states for parallel draft. + # hidden_states will be replaced with mask token embeddings + # after padding + diffusion_tokens = self.mask_token.repeat( + eagle_ids.shape[0], self.eagle_config.parallel_draft_step - 1 + ).to(eagle_ids.device) + eagle_ids = torch.cat( + ( + eagle_ids, + diffusion_tokens, + ), + dim=-1, + ) + hidden_states = torch.cat( + (hidden_states, hidden_states[-self.eagle_config.parallel_draft_step + 1 :]), + dim=0, + ) padded_eagle_ids, seq_len, padded_hidden_states = right_padding( eagle_ids, hidden_states ) - if self.config.sequence_parallel: - padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states) eagle_attention_mask, eagle_position_ids = get_default_attention_mask_and_position_ids( padded_eagle_ids ) eagle_inputs = {} eagle_inputs["input_ids"] = padded_eagle_ids - eagle_inputs["embedding"] = self.embedding( - input_ids=padded_eagle_ids, - position_ids=eagle_position_ids, - ) + if self.eagle_config.parallel_draft_step > 1: + embeddings = self.eagle_module.embedding( + input_ids=padded_eagle_ids, + position_ids=eagle_position_ids, + ) + else: + embeddings = self.embedding( + input_ids=padded_eagle_ids, + position_ids=eagle_position_ids, + ) + if self.config.sequence_parallel: + gathered_embedding = gather_from_sequence_parallel_region(embeddings) + else: + gathered_embedding = embeddings + if 
self.eagle_config.parallel_draft_step > 1: + # Replace dummy hidden_states with embeddings of mask_token + padded_hidden_states[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] = gathered_embedding[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] + if self.config.sequence_parallel: + padded_hidden_states = scatter_to_sequence_parallel_region(padded_hidden_states) + embeddings = scatter_to_sequence_parallel_region(gathered_embedding) + eagle_inputs["embedding"] = embeddings eagle_inputs["hidden_states"] = padded_hidden_states + eagle_inputs["attention_mask"] = eagle_attention_mask + # Adjust attention mask for diffusion + if self.eagle_config.parallel_draft_step > 1: + # Bidirectional attention for the diffusion tokens + eagle_inputs["attention_mask"][ + :, + :, + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len :, + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len :, + ] = False # [TODO] (chenhany): let the module compute itself eagle_inputs["rotary_pos_emb"] = None - _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( - eagle_inputs, - output_weight, - ) + diffusion_step = 0 + expected_num_tokens = self.eagle_config.parallel_draft_step + first_iteration = True + while expected_num_tokens > 0: + diffusion_step += 1 + _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( + eagle_inputs, + output_weight, + ) + + if first_iteration: + first_iteration = False + expected_num_tokens -= 1 + if self.eagle_config.parallel_draft_step > 1: + diffusion_logits = gather_from_tensor_model_parallel_region(eagle_logits)[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len, + :, + :, + ] + diffusion_tokens, num_accepted_tokens = self._accept_diffusion_tokens( + diffusion_tokens, + diffusion_logits, + threshold, + ) + assert num_accepted_tokens > 0, ( + "At least one token should be accepted in diffusion." 
+ ) + expected_num_tokens -= num_accepted_tokens + # Update embeddings and hidden_states for next diffusion step + padded_eagle_ids[ + :, seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] = diffusion_tokens + embeddings = self.eagle_module.embedding( + input_ids=padded_eagle_ids, + position_ids=eagle_position_ids, + ) + if self.config.sequence_parallel: + gathered_embedding = gather_from_sequence_parallel_region(embeddings) + else: + gathered_embedding = embeddings + # Replace diffusion hidden_states with embeddings + padded_hidden_states[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] = gathered_embedding[ + seq_len - self.eagle_config.parallel_draft_step + 1 : seq_len + ] + if self.config.sequence_parallel: + padded_hidden_states = scatter_to_sequence_parallel_region( + padded_hidden_states + ) + embeddings = scatter_to_sequence_parallel_region(gathered_embedding) + eagle_inputs["embedding"] = embeddings + eagle_inputs["hidden_states"] = padded_hidden_states eagle_logits = eagle_logits[:seq_len, :, :] if self.config.sequence_parallel: eagle_next_hidden_states_input = gather_from_sequence_parallel_region( @@ -1720,33 +1664,91 @@ def pseudo_speculative_generate( if self.eagle_config.parallel_draft_step > 1: draft_token = ( gather_from_tensor_model_parallel_region(eagle_logits)[ - -self.eagle_config.parallel_draft_step :, :, : + -self.eagle_config.parallel_draft_step : -self.eagle_config.parallel_draft_step + + 1, + :, + :, ] .argmax(dim=-1) .transpose(0, 1) ) else: draft_token = ( - gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :] + gather_from_tensor_model_parallel_region(eagle_logits)[ + -1:, + :, + :, + ] .argmax(dim=-1) .transpose(0, 1) ) + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] - if self.eagle_config.parallel_draft_step > 1: - return base_token, draft_token - draft_tokens.append(draft_token) + # Remove mask tokens from 
eagle_ids before adding draft_token + # Remove added hidden_states + if self.eagle_config.parallel_draft_step > 1: + eagle_ids = eagle_ids[:, : -self.eagle_config.parallel_draft_step + 1] + hidden_states = hidden_states[: -self.eagle_config.parallel_draft_step + 1] + eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1) - hidden_states = torch.cat( - (hidden_states, eagle_next_hidden_states_input[-1:, :, :]), dim=0 - ) + if self.eagle_config.parallel_draft_step > 1: + hidden_states = torch.cat( + ( + hidden_states, + eagle_next_hidden_states_input[ + -self.eagle_config.parallel_draft_step : -self.eagle_config.parallel_draft_step + + 1, + :, + :, + ], + ), + dim=0, + ) + else: + hidden_states = torch.cat( + (hidden_states, eagle_next_hidden_states_input[-1:, :, :]), dim=0 + ) draft_tokens = torch.cat(draft_tokens, dim=-1) + if self.eagle_config.parallel_draft_step > 1: + # Add diffusion tokens to draft_tokens + draft_tokens = torch.cat( + ( + draft_tokens, + diffusion_tokens, + ), + dim=-1, + ) + + return base_token, draft_tokens, diffusion_step - return base_token, draft_tokens + def _accept_diffusion_tokens(self, diffusion_tokens, diffusion_logits, threshold): + assert diffusion_logits.shape[1] == 1, ( + "Batch size > 1 is not supported in _accept_diffusion_tokens."
+ ) + p = torch.softmax(diffusion_logits, dim=-1) + entropy = -(p * torch.log(p)).sum(dim=-1) + # Only consider masked tokens + # Manually change entropy to a large value to reject unmasked tokens + entropy = entropy * (diffusion_tokens.transpose(0, 1) == self.mask_token).float() + ( + 1e10 * (diffusion_tokens.transpose(0, 1) != self.mask_token).float() + ) + sorted_values, indices = torch.sort(entropy, dim=0, descending=False) + cumulative_values = torch.cumsum(sorted_values, dim=0) + mask = (cumulative_values > threshold).int() + # First index where cumulative_values > threshold + idx = torch.argmax(mask, dim=0) + # Ensure at least one token is accepted + idx = torch.where(idx == 0, torch.tensor(1, device=idx.device), idx) + accepted_tokens_idx = indices[:idx] + diffusion_tokens[:, accepted_tokens_idx.squeeze(1)] = ( + diffusion_logits[accepted_tokens_idx.squeeze(1)].argmax(dim=-1).transpose(0, 1) + ) + return diffusion_tokens, idx.item() class MegatronARValidation(AcceptanceRateValidation): @@ -1756,7 +1758,7 @@ def get_ground_truth(self, input_ids, osl): """This function returns ground truth output tokens from the base model.""" input_ids = copy.deepcopy(input_ids) for _ in range(osl): - input_id, _ = self.model.pseudo_speculative_generate(input_ids, steps=0) + input_id, _, _ = self.model.pseudo_speculative_generate(input_ids, steps=0) input_ids = torch.cat((input_ids, input_id), dim=-1) if input_id[0, 0] == self.end_token: break diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index 648cc8163..0e97cedab 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -290,11 +290,13 @@ def check_draft(self, ground_truth, input_ids, draft_tokens): return input_ids - def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=True): + def check_data_consistency_across_ranks(self, data, group=None, fail_when_mismatch=False): """This function checks the data consistency 
across all ranks in the group. Use rank 0 data as the golden set to broadcast to all ranks. - Each rank will then compare to this data and through error if different. + Each rank compares its data against this golden set and either raises + (when fail_when_mismatch=True) or emits a warning while forcing every + rank to adopt rank 0's data. """ if not torch.distributed.is_initialized(): return data @@ -338,6 +340,7 @@ def validate( if tree_paths: tree = Tree(tree_paths) + diffusion_steps = 0 while input_ids.shape[1] < ground_truth.shape[1]: cnt += 1 input_ids = self.check_draft(ground_truth, input_ids, draft_tokens) @@ -346,23 +349,21 @@ def validate( if tree_paths: input_id, draft_tokens, pred_tokens = self.model.tree_decode(input_ids, tree=tree) - pred_tokens = self.check_data_consistency_across_ranks( - pred_tokens, fail_when_mismatch=False - ) + pred_tokens = self.check_data_consistency_across_ranks(pred_tokens) else: - input_id, draft_tokens = self.model.pseudo_speculative_generate( + input_id, draft_tokens, diffusion_step = self.model.pseudo_speculative_generate( input_ids, steps=steps ) - draft_tokens = self.check_data_consistency_across_ranks( - draft_tokens, fail_when_mismatch=False - ) + draft_tokens = self.check_data_consistency_across_ranks(draft_tokens) + diffusion_steps += diffusion_step input_id = self.check_data_consistency_across_ranks(input_id) input_ids = torch.cat((input_ids, input_id), dim=-1) ar = (ground_truth.shape[1] - isl) / cnt + diffusion_steps /= cnt - return ground_truth, ar + return ground_truth, ar, diffusion_steps @contextlib.contextmanager