diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 4870d08b5..cdd1b8ffc 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]: def make_eagle_supervised_data_module( - tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool + tokenizer: transformers.PreTrainedTokenizer, + data_args, + use_offline_training: bool, + max_length=None, ) -> dict: """Make dataset and collator for supervised fine-tuning. @@ -295,7 +298,7 @@ def make_eagle_supervised_data_module( train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer) eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer) - data_collator = DataCollatorForOffline() + data_collator = DataCollatorForOffline(max_length=max_length) else: print_rank_0("Loading input conversations...") dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset @@ -303,7 +306,7 @@ def make_eagle_supervised_data_module( train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer) eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer) - data_collator = DataCollatorWithPadding() + data_collator = DataCollatorWithPadding(max_length=max_length) return { "train_dataset": train_dataset, @@ -313,6 +316,9 @@ def make_eagle_supervised_data_module( class DataCollatorWithPadding: + def __init__(self, max_length): + self.max_length = max_length + def paddingtensor2d(self, intensors, length): n, dim = intensors.shape padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) @@ -325,19 +331,18 @@ def paddingtensor(self, intensors, length): return outtensors def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - max_length = max(item["input_ids"].shape[0] for item in features) batch_input_ids = torch.stack( - [self.paddingtensor(item["input_ids"], max_length) for item in features] + [self.paddingtensor(item["input_ids"], self.max_length) for item in features] ) batch_attention_mask = torch.stack( - [self.paddingtensor(item["attention_mask"], max_length) for item in features] + [self.paddingtensor(item["attention_mask"], self.max_length) for item in features] ) batch_loss_mask = torch.stack( - [self.paddingtensor(item["loss_mask"], max_length) for item in features] + [self.paddingtensor(item["loss_mask"], self.max_length) for item in features] ) batch_labels = torch.stack( - [self.paddingtensor(item["labels"], max_length) for item in features] + [self.paddingtensor(item["labels"], self.max_length) for item in features] ) batch = { @@ -357,16 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: raise ValueError("No kwargs found in batch features. 
Offline data required.") features = [item["kwargs"]["base_model_outputs"] for item in features] - max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features) batch_hidden_states = torch.stack( [ - self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length) + self.paddingtensor2d(item["base_model_hidden_states"], self.max_length) for item in features ] ) batch_aux_hidden_states = torch.stack( - [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features] + [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features] ) batch = { diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index ea711765b..d1373363e 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -227,7 +227,9 @@ def train(): if training_args.mode == "medusa": data_module = make_medusa_supervised_data_module(tokenizer, data_args) elif training_args.mode in ["eagle1", "eagle3"]: - data_module = make_eagle_supervised_data_module(tokenizer, data_args, use_offline_training) + data_module = make_eagle_supervised_data_module( + tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len + ) class ARValidationCallback(TrainerCallback): def __init__(self, ar_validate_steps: int = 500): diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index af617d83f..0fc6fb11b 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -36,6 +36,7 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss +from torch.nn.attention.flex_attention import BlockMask, create_block_mask from transformers import Cache, DynamicCache, PretrainedConfig, PreTrainedModel from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -182,6 +183,10 @@ def __init__(self, config, decoder_layer_cls, bias=False): """Init function for EagleModule.""" super().__init__() self.config = config + + # Use flex attention for efficient TTT + config._attn_implementation = "flex_attention" + self.layers = nn.ModuleList( [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) @@ -311,7 +316,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_values=past_key_values, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -446,6 +451,17 @@ def modify( if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) + self.num_ttt_steps = 3 # NOTE: (hg) hardcoded for now. Might add to config later. 
+ self._cached_attn_blk_masks = [] + + def _get_ttt_attention_mask(self, seq_length, ttt_step): + # compile and cached flex attention masks in first call + if ttt_step >= len(self._cached_attn_blk_masks): + self._cached_attn_blk_masks.append(self._compile_ttt_block_mask(seq_length, ttt_step)) + + # return cached flex attention mask + return self._cached_attn_blk_masks[ttt_step] + def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length ): @@ -523,167 +539,40 @@ def _get_eagle_module_inputs( return eagle_input_ids, attention_mask, position_ids - def _concat_eagle_inputs( - self, - input_ids_0, - eagle_input_hidden_states_0, - attention_mask_0, - position_ids_0, - eagle_generated_hs, - ): - """Helper function to prepare eagle inputs for second-fourth eagle forward pass during training-time-testing. - - This is a slow version, focusing on the correctness only. TODO: optimize this. - Parameters: - input_ids_0: [b, seq_length], input_ids from the 0th eagle step - base_model_hidden_states: [b, seq_length, h] - eagle_input_hidden_states_0: [b, seq_length, h] - attention_mask_0: [b, seq_length, seq_length], from the 0th eagle step. - position_ids_0: [b, seq_length], from the 0th eagle step. - eagle_generated_hs: [b, seq_length * n_steps, h], from the LAST eagle step. - """ - b, seq_length, h = eagle_input_hidden_states_0.shape - dtypemin = torch.finfo(attention_mask_0.dtype).min - - if eagle_generated_hs.shape[1] == seq_length: - # This is the second step of eagle forward - - # Concat input_ids - cat_input_ids = torch.cat((input_ids_0, input_ids_0), dim=-1) - - # Concat hidden_states - cat_eagle_input_hidden_states = torch.cat( - ( - eagle_input_hidden_states_0, - torch.zeros( - (b, 1, h), - dtype=eagle_input_hidden_states_0.dtype, - device=eagle_input_hidden_states_0.device, - ), - eagle_generated_hs[:, :-1, :], - ), - dim=1, - ) + def _compile_ttt_block_mask(self, seq_length, ttt_step) -> BlockMask: + """Compile TTT attention_masks with symbolic masks and return a BlockMask object for flex attention.""" + if ttt_step == 0: - # Expand attn_mask - zero_mask = torch.ones_like(attention_mask_0).bool() - mask_2_1 = attention_mask_0.clone().detach() - mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:] - mask_2_2 = torch.ones_like(attention_mask_0).bool() - for i in range(1, seq_length - 1): - mask_2_2[:, :, i, i] = False - cat_attention_mask = torch.cat( - ( - torch.cat((attention_mask_0, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2), dim=-1), - ), - dim=-2, - ) - cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin) - - # Concat position_ids - cat_position_ids = torch.cat((position_ids_0, position_ids_0), dim=-1) - - elif eagle_generated_hs.shape[1] == seq_length * 2: - cat_input_ids = torch.cat((input_ids_0, input_ids_0, input_ids_0), dim=-1) - cat_eagle_input_hidden_states = torch.cat( - ( - eagle_input_hidden_states_0, - torch.zeros( - (b, 1, h), - dtype=eagle_input_hidden_states_0.dtype, - device=eagle_input_hidden_states_0.device, - ), - eagle_generated_hs[:, :-1, :], - ), - dim=1, - ) - zero_mask = torch.ones_like(attention_mask_0).bool() - mask_2_1 = attention_mask_0.clone().detach() - mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:] - mask_2_2 = torch.ones_like(attention_mask_0).bool() - for i in range(1, seq_length - 1): - mask_2_2[:, :, i, i] = False - - mask_3_1 = mask_2_1.clone().detach() - mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:] - mask_3_2 = mask_2_2.clone().detach() - mask_3_2[:, :, :, :-1] = 
mask_3_2[:, :, :, 1:] - mask_3_2[:, :, 1, 0] = True - mask_3_3 = mask_2_2.clone().detach() - mask_3_3[:, :, 1, 1] = True - cat_attention_mask = torch.cat( - ( - torch.cat((attention_mask_0, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3), dim=-1), - ), - dim=-2, - ) + def msk(b, h, q_idx, kv_idx): + # symbolic attention mask of shape [seq_len, 2* seq_len] for TTT step 0 + return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length) - cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin) - cat_position_ids = torch.cat((position_ids_0, position_ids_0, position_ids_0), dim=-1) - - elif eagle_generated_hs.shape[1] == seq_length * 3: - cat_input_ids = torch.cat((input_ids_0, input_ids_0, input_ids_0, input_ids_0), dim=-1) - cat_eagle_input_hidden_states = torch.cat( - ( - eagle_input_hidden_states_0, - torch.zeros( - (b, 1, h), - dtype=eagle_input_hidden_states_0.dtype, - device=eagle_input_hidden_states_0.device, - ), - eagle_generated_hs[:, :-1, :], - ), - dim=1, - ) - zero_mask = torch.ones_like(attention_mask_0).bool() - mask_2_1 = attention_mask_0.clone().detach() - mask_2_1[:, :, :, :-1] = mask_2_1[:, :, :, 1:] - mask_2_2 = torch.ones_like(attention_mask_0).bool() - for i in range(1, seq_length - 1): - mask_2_2[:, :, i, i] = False - - mask_3_1 = mask_2_1.clone().detach() - mask_3_1[:, :, :, :-1] = mask_3_1[:, :, :, 1:] - mask_3_2 = mask_2_2.clone().detach() - mask_3_2[:, :, :, :-1] = mask_3_2[:, :, :, 1:] - mask_3_2[:, :, 1, 0] = True - mask_3_3 = mask_2_2.clone().detach() - mask_3_3[:, :, 1, 1] = True - - mask_4_1 = mask_3_1.clone().detach() - mask_4_1[:, :, :, :-1] = mask_4_1[:, :, :, 1:] - mask_4_2 = mask_3_2.clone().detach() - mask_4_2[:, :, :, :-1] = mask_4_2[:, :, :, 1:] - mask_4_2[:, :, 2, 0] = True - mask_4_3 = mask_3_3.clone().detach() - mask_4_3[:, :, :, :-1] = mask_4_3[:, :, :, 1:] - mask_4_3[:, :, 2, 1] = True - mask_4_4 = mask_3_3.clone().detach() - mask_4_4[:, :, 2, 2] = True - - cat_attention_mask = torch.cat( - ( - torch.cat((attention_mask_0, zero_mask, zero_mask, zero_mask), dim=-1), - torch.cat((mask_2_1, mask_2_2, zero_mask, zero_mask), dim=-1), - torch.cat((mask_3_1, mask_3_2, mask_3_3, zero_mask), dim=-1), - torch.cat((mask_4_1, mask_4_2, mask_4_3, mask_4_4), dim=-1), - ), - dim=-2, - ) - cat_attention_mask = cat_attention_mask.masked_fill(cat_attention_mask == 1, dtypemin) - cat_position_ids = torch.cat( - (position_ids_0, position_ids_0, position_ids_0, position_ids_0), dim=-1 - ) + return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 2) + elif ttt_step == 1: - else: - raise ValueError( - f"EAGLE generated hidden states shape {eagle_generated_hs.shape} is not supported" - ) + def msk(b, h, q_idx, kv_idx): + # attention mask of shape [seq_len, 3* seq_len] for TTT step 1 + return ( + (kv_idx <= (q_idx - 2)) + | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length)) + | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2)) + ) + + return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 3) + elif ttt_step == 2: - return cat_eagle_input_hidden_states, cat_input_ids, cat_attention_mask, cat_position_ids + def msk(b, h, q_idx, kv_idx): + # attention mask of shape [seq_len, 4* seq_len] for TTT step 2 + return ( + (kv_idx <= (q_idx - 3)) + | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length)) + | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2)) + | 
((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3)) + ) + + return create_block_mask(msk, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length * 4) + else: + raise ValueError(f"EAGLE TTT step {ttt_step} is not supported") def _base_model_forward( self, @@ -739,6 +628,7 @@ def _eagle_forward( attention_mask, position_ids, position_embeddings, + eagle_cache=None, ): eagle_postnorm_h, eagle_prenorm_h, eagle_cache = self.eagle_module( eagle_input_hidden_states, @@ -747,6 +637,7 @@ def _eagle_forward( position_ids=position_ids, use_cache=True, position_embeddings=position_embeddings, + past_key_values=eagle_cache, ) eagle_lm_head = ( self.eagle_module.eagle_lm_head @@ -771,8 +662,6 @@ def forward( cache_position: torch.LongTensor | None = None, logits_to_keep: int = 0, loss_mask: torch.Tensor | None = None, - classification_loss_coefficient: float | None = 1, - regression_loss_coefficient: float | None = 0, **kwargs, ) -> Any: """Forward pass of the EagleModel. @@ -824,12 +713,15 @@ def forward( if not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if not isinstance(eagle_cache, Cache): + eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) # ====Run eagle forward==== eagle_loss = None + train_accs = [] if self.training: # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers - batch_size, seq_length, _ = base_model_hidden_states.shape + b, seq_length, h = base_model_hidden_states.shape if self.eagle_config.use_aux_hidden_state: if "base_model_outputs" in kwargs: aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"] @@ -852,168 +744,66 @@ def forward( position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids) # Then, we run eagle forward - eagle_postnorm_h, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( + _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hidden_states, inputs_embeds, attention_mask_0, position_ids, position_embeddings, + eagle_cache, ) - if not isinstance(eagle_cache, Cache): - eagle_cache = DynamicCache.from_legacy_cache(eagle_cache) - past_key_values.eagle_cache = eagle_cache # Compute loss on the eagle modules - regression_loss, classification_loss, accuracy_0 = self._eagle_loss( - base_model_hidden_states[:, 1:], + classification_loss, acc = self._eagle_loss( base_model_logits[:, 1:], - eagle_postnorm_h[:, :-1], eagle_logits[:, :-1], loss_mask[:, 1:], ) - eagle_loss = ( - regression_loss_coefficient * regression_loss - + classification_loss_coefficient * classification_loss - ) + eagle_loss = classification_loss + train_accs.append(acc) # ====Perform training-time-testing with 3 extra eagle forward passes==== - # ====Second step of eagle forward==== - eagle_input_hidden_states_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = ( - self._concat_eagle_inputs( - eagle_input_ids, - eagle_input_hidden_states, - attention_mask_0, - position_ids, - eagle_prenorm_h, - ) - ) - with torch.no_grad(): - inputs_embeds = self.model.embed_tokens(eagle_input_ids_1) - position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states_1, position_ids_1) - eagle_postnorm_h, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states_1, - inputs_embeds, - attention_mask_1, - position_ids_1, - position_embeddings, - ) - - regression_loss, classification_loss, accuracy_1 = self._eagle_loss( - # base model predict +1 tok, while eagle predict 
+2 - # so we shift base model outputs compared to eagle outputs - base_model_hidden_states[:, 1:], - base_model_logits[:, 1:], - eagle_postnorm_h[ - :, - -seq_length:-1, - ], - eagle_logits[ - :, - -seq_length:-1, - ], - # additionally, we mask the first n tok of eagle outputs at nth TTT step - torch.cat( + for ttt_step in range(self.num_ttt_steps): + eagle_input_hidden_states = torch.cat( ( - torch.zeros(batch_size, 1, dtype=loss_mask.dtype, device=loss_mask.device), - loss_mask[:, 2:], + torch.zeros( + (b, 1, h), + dtype=eagle_input_hidden_states.dtype, + device=eagle_input_hidden_states.device, + ), + eagle_prenorm_h[:, :-1, :], ), dim=1, - ), - ) - eagle_loss += ( - regression_loss_coefficient * regression_loss - + classification_loss_coefficient * classification_loss - ) - - # ====Third step of eagle forward==== - eagle_input_hidden_states_2, eagle_input_ids_2, attention_mask_2, position_ids_2 = ( - self._concat_eagle_inputs( - eagle_input_ids, - eagle_input_hidden_states, - attention_mask_0, - position_ids, - eagle_prenorm_h, ) - ) - with torch.no_grad(): - inputs_embeds = self.model.embed_tokens(eagle_input_ids_2) - position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states_2, position_ids_2) - eagle_postnorm_h, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states_2, - inputs_embeds, - attention_mask_2, - position_ids_2, - position_embeddings, - ) - - regression_loss, classification_loss, accuracy_2 = self._eagle_loss( - base_model_hidden_states[:, 1:], - base_model_logits[:, 1:], - eagle_postnorm_h[:, -seq_length:-1, :], - eagle_logits[ - :, - -seq_length:-1, - ], - torch.cat( - ( - torch.zeros(batch_size, 2, dtype=loss_mask.dtype, device=loss_mask.device), - loss_mask[:, 3:], - ), - dim=1, - ), - ) - eagle_loss += ( - regression_loss_coefficient * regression_loss - + classification_loss_coefficient * classification_loss - ) - - # ====Fourth step of eagle forward==== - eagle_input_hidden_states_3, eagle_input_ids_3, attention_mask_3, position_ids_3 = ( - self._concat_eagle_inputs( - eagle_input_ids, + attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step) + _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward( eagle_input_hidden_states, - attention_mask_0, + inputs_embeds, + attention_mask, position_ids, - eagle_prenorm_h, + position_embeddings, + eagle_cache, ) - ) - with torch.no_grad(): - inputs_embeds = self.model.embed_tokens(eagle_input_ids_3) - position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states_3, position_ids_3) - eagle_postnorm_h, _, eagle_logits, eagle_cache = self._eagle_forward( - eagle_input_hidden_states_3, - inputs_embeds, - attention_mask_3, - position_ids_3, - position_embeddings, - ) - - regression_loss, classification_loss, accuracy_3 = self._eagle_loss( - base_model_hidden_states[:, 1:], - base_model_logits[:, 1:], - eagle_postnorm_h[ - :, - -seq_length:-1, - ], - eagle_logits[ - :, - -seq_length:-1, - ], - torch.cat( - ( - torch.zeros(batch_size, 3, dtype=loss_mask.dtype, device=loss_mask.device), - loss_mask[:, 4:], + classification_loss, acc = self._eagle_loss( + # base model predict +1 tok, while eagle predict +2 + # so we shift base model outputs compared to eagle outputs + base_model_logits[:, 1:], + eagle_logits[:, :-1], + # additionally, we mask the first n tok of eagle outputs at nth TTT step + torch.cat( + ( + torch.zeros( + b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device + ), + loss_mask[:, 2 + ttt_step :], + ), + dim=1, ), - dim=1, - ), 
- ) - eagle_loss += ( - regression_loss_coefficient * regression_loss - + classification_loss_coefficient * classification_loss - ) - + ) + eagle_loss += classification_loss + train_accs.append(acc) # Finally, we merge base model loss and eagle loss, raise error if both are None if base_model_loss is not None and eagle_loss is not None: loss = base_model_loss + eagle_loss @@ -1027,37 +817,28 @@ def forward( "Both base_model_loss and eagle_loss are skipped. At least one loss must be computed." ) - train_acc = (accuracy_0, accuracy_1, accuracy_2, accuracy_3) if self.training else None - return ModelOutput( loss=loss, logits=base_model_logits, past_key_values=past_key_values, hidden_states=base_model_hidden_states, - train_acc=train_acc, + train_acc=train_accs, ) def _eagle_loss( self, - base_model_hidden_states, base_model_logits, - eagle_hidden_states, eagle_logits, loss_mask, ): """Function for EAGLE loss computing.""" loss_mask = loss_mask[:, :, None] - criterion = nn.SmoothL1Loss(reduction="none") classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)( eagle_logits ) classification_loss = -torch.sum(torch.sum(loss_mask * classification_loss, 2)) / ( loss_mask.sum() + 1e-5 ) - regression_loss = criterion(eagle_hidden_states, base_model_hidden_states) - regression_loss = torch.sum(torch.mean(loss_mask * regression_loss, 2)) / ( - loss_mask.sum() + 1e-5 - ) # Compute accuracy base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1) eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1) @@ -1066,7 +847,7 @@ def _eagle_loss( denom = valid.sum().clamp_min(1).float() accuracy = round(correct.sum().float().div(denom).item(), 3) - return regression_loss, classification_loss, accuracy + return classification_loss, accuracy @torch.no_grad() def pseudo_speculative_generate( diff --git a/tests/unit/torch/speculative/plugins/test_hf_speculative.py b/tests/unit/torch/speculative/plugins/test_hf_speculative.py index 51d93996e..84b00be6d 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_speculative.py +++ b/tests/unit/torch/speculative/plugins/test_hf_speculative.py @@ -17,7 +17,6 @@ from copy import deepcopy import pytest -import torch from _test_utils.torch_model.transformers_models import ( create_tiny_llama_dir, get_tiny_llama, @@ -69,122 +68,3 @@ def test_eagle_model_convert_save_and_restore(tmp_path, eagle_config): model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") assert isinstance(model_test, mtsp.plugins.HFEagleModel) tf_modelopt_state_and_output_tester(model_ref, model_test) - - -# fmt: off -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_eagle_model_prepare_eagle_inputs(dtype): - dummy_model = get_tiny_llama(num_hidden_layers=4) - - config = EAGLE3_DEFAULT_CFG["config"] - config["eagle_architecture_config"].update({ - "draft_vocab_size": dummy_model.config.vocab_size, - "hidden_size": dummy_model.config.hidden_size, - }) - mtsp.convert(dummy_model, mode=[("eagle", config)]) - - eagle_input_ids_0 = torch.tensor([[10, 20, 30, 40]], dtype=torch.long) - position_ids_0 = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) - - - #This is concatenated from 3 intermediate base model layers - cat_aux_hidden_states = torch.randn(1, 4, 32, dtype=dtype) - - #This is eagle output from previous eagle forward pass - dummy_eagle_output_hidden_states = torch.randn(1, 4, 32, dtype=dtype) - - #This is the causal mask for the 0th eagle step - m = torch.finfo(dtype).min - attention_mask_0 = torch.tensor([[0, m, m, m], # 
input tok 10-> predicting token 20 - [0, 0, m, m], # 20 -> 30 - [0, 0, 0, m], # 30 -> 40 - [0, 0, 0, 0]] # 40 -> tok after 40 - - , dtype=dtype).view(1, 1, 4, 4) - - # 2nd eagle step - eagle_input_h_1, eagle_input_ids_1, attention_mask_1, position_ids_1 = dummy_model._concat_eagle_inputs( - eagle_input_ids_0, - cat_aux_hidden_states, - attention_mask_0, - position_ids_0, - dummy_eagle_output_hidden_states, - ) - - assert eagle_input_ids_1.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long)) - assert position_ids_1.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)) - - assert attention_mask_1.equal(torch.tensor([[0, m, m, m, m, m, m, m], # (x) output discarded - [0, 0, m, m, m, m, m, m], # (x) - [0, 0, 0, m, m, m, m, m], # (x) - [0, 0, 0, 0, m, m, m, m], # (x) - - [m, m, m, m, m, m, m, m], # (x) input tok 10-> predicting token 20 - [0, m, m, m, m, 0, m, m], # 20 -> 30 - [0, 0, m, m, m, m, 0, m], # 30 -> 40 - [0, 0, 0, 0, m, m, m, m], # (x) 40 -> tok after 40 - ], dtype=dtype).view(1, 1, 8, 8)) - - # 3rd eagle step - eagle_input_hidden_states_2, eagle_input_ids_2, attention_mask_2, position_ids_2 = dummy_model._concat_eagle_inputs( - eagle_input_ids_0, - cat_aux_hidden_states, - attention_mask_0, - position_ids_0, - torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states], dim=1), - ) - assert eagle_input_ids_2.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long)) - assert position_ids_2.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)) - - assert attention_mask_2.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, 0, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) - - [m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, m, m, m, m, 0, m, m, m, m, m, m], # (x) - [0, 0, m, m, m, m, 0, m, m, m, m, m], # (x) - [0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) - - [m, m, m, m, m, m, m, m, m, m, m, m], # (x)10 -> 20 - [m, m, m, m, m, m, m, m, m, m, m, m], # (x)20 -> 30 - [0, m, m, m, m, 0, m, m, m, m, 0, m], # 30 -> 40 - [0, 0, 0, 0, m, m, m, m, m, m, m, m], # (x) 40 -> tok after 40 - - ], dtype=dtype).view(1, 1, 12, 12)) - - # 4th eagle step - eagle_input_hidden_states_3, eagle_input_ids_3, attention_mask_3, position_ids_3 = dummy_model._concat_eagle_inputs( - eagle_input_ids_0, - cat_aux_hidden_states, - attention_mask_0, - position_ids_0, - torch.cat([dummy_eagle_output_hidden_states, dummy_eagle_output_hidden_states, - dummy_eagle_output_hidden_states],dim=1), - ) - - assert eagle_input_ids_3.equal(torch.tensor([[10, 20, 30, 40, 10, 20, 30, 40, - 10, 20, 30, 40, 10, 20, 30, 40]], dtype=torch.long)) - assert position_ids_3.equal(torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]], dtype=torch.long)) - - assert attention_mask_3.equal(torch.tensor([[0, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, m, m, m, m, 0, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, m, m, m, m, 0, m, m, m, m, 0, m, m, m, m, m], # (x) 
- [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)10 -> 20 - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x)20 -> 30 - [m, m, m, m, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - [0, 0, 0, 0, m, m, m, m, m, m, m, m, m, m, m, m], # (x) - - ], dtype=dtype).view(1, 1, 16, 16))
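
# Minimal sketch (illustrative only, not part of the patch) of the fixed-length padding
# this change introduces: the collators now pad every batch to a caller-supplied
# max_length (training_args.training_seq_len in main.py) instead of the per-batch
# maximum, so sequence shapes stay static across steps and the flex-attention
# BlockMasks compiled per TTT step can be cached and reused. The pad_1d helper and the
# feature tensors below are made up for illustration.
import torch


def pad_1d(t: torch.Tensor, length: int) -> torch.Tensor:
    # Right-pad a 1-D tensor with zeros up to `length`, mirroring paddingtensor().
    return torch.cat([t, torch.zeros(length - t.shape[0], dtype=t.dtype)])


features = [
    {"input_ids": torch.tensor([10, 20, 30])},
    {"input_ids": torch.tensor([10, 20, 30, 40, 50])},
]
max_length = 8  # fixed training sequence length, assumed here

batch_input_ids = torch.stack([pad_1d(f["input_ids"], max_length) for f in features])
print(batch_input_ids.shape)  # torch.Size([2, 8]) -- identical shape for every batch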
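
# Minimal sketch (illustrative only, not part of the patch) of the TTT step-0 mask
# predicate from _compile_ttt_block_mask, evaluated densely on CPU for a tiny sequence
# so the pattern is visible. In the patch the same predicate is handed to
# torch.nn.attention.flex_attention.create_block_mask (Q_LEN=seq_length,
# KV_LEN=2*seq_length) rather than materializing a dense mask as the removed
# _concat_eagle_inputs did. seq_length = 4 is arbitrary.
import torch

seq_length = 4
q_idx = torch.arange(seq_length).unsqueeze(1)       # [seq_len, 1]
kv_idx = torch.arange(2 * seq_length).unsqueeze(0)  # [1, 2 * seq_len]

# Query i may attend to cached keys at positions <= i - 1 from the previous eagle pass,
# plus its own key in the current pass (the diagonal at offset seq_length).
step0_mask = (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)
print(step0_mask.int())
# tensor([[0, 0, 0, 0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 0],
#         [1, 1, 1, 0, 0, 0, 0, 1]], dtype=torch.int32)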
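
# Minimal sketch (illustrative only, not part of the patch) of the classification-only
# draft loss kept in _eagle_loss after the SmoothL1 regression term was removed: soft
# cross-entropy between the base model's token distribution and the eagle logits,
# averaged over unmasked positions, plus a greedy token-match accuracy. Shapes and
# random values below are made up.
import torch
import torch.nn.functional as F

b, s, v = 2, 5, 11  # batch, sequence, vocab sizes (arbitrary)
base_logits, eagle_logits = torch.randn(b, s, v), torch.randn(b, s, v)
loss_mask = torch.ones(b, s, 1)

soft_ce = -(F.softmax(base_logits, dim=-1) * F.log_softmax(eagle_logits, dim=-1))
loss = (loss_mask * soft_ce).sum() / (loss_mask.sum() + 1e-5)

valid = loss_mask.squeeze(-1).bool()
acc = (base_logits.argmax(-1) == eagle_logits.argmax(-1))[valid].float().mean()
print(loss.item(), acc.item())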