From ae2aa934b8088f49f3d4763ac706945d49a6ffbe Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Thu, 25 Sep 2025 23:28:30 +0000
Subject: [PATCH] fix: eagle3 quantized base model

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/eagle_utils.py  | 31 ++++++++++++++++
 examples/speculative_decoding/main.py         | 37 ++-----------------
 .../torch/speculative/plugins/transformers.py | 30 ++++++++++-----
 modelopt/torch/speculative/utils.py           | 14 +++++++
 4 files changed, 69 insertions(+), 43 deletions(-)

diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index cdd1b8ffc..452bb4ded 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -19,11 +19,21 @@
 
 import torch
 import transformers
+from ar_validate import validate_ar
+from datasets import load_dataset
 from torch.utils.data import Dataset
+from transformers import TrainerCallback
 from transformers.trainer_pt_utils import LabelSmoother
 
 from modelopt.torch.utils import print_rank_0
 
+try:
+    import wandb
+
+    wandb.init()
+except ImportError:
+    wandb = None
+
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
 REMOVE_THINK_CHAT_TEMPLATE = (
@@ -382,3 +392,24 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         }
 
         return batch
+
+
+class ARValidationCallback(TrainerCallback):
+    def __init__(self, ar_validate_steps: int = 1000):
+        self.ar_validate_steps = ar_validate_steps
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.ar_validate_steps <= 0:
+            return control
+        if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
+            print_rank_0("Running AR validation...")
+            ars = validate_ar(
+                model=kwargs["model"],
+                tokenizer=kwargs["processing_class"],
+                ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
+                device=kwargs["model"].device,
+            )
+            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+            if wandb:
+                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+        return control
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index d1373363e..20242c795 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -36,24 +36,15 @@
 
 import torch
 import transformers
-from ar_validate import validate_ar
-from datasets import load_dataset
-from eagle_utils import make_eagle_supervised_data_module
+from eagle_utils import ARValidationCallback, make_eagle_supervised_data_module
 from medusa_utils import make_medusa_supervised_data_module
-from transformers import Trainer, TrainerCallback
+from transformers import Trainer
 from transformers.trainer_utils import get_last_checkpoint
 
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.utils import print_rank_0
 
-try:
-    import wandb
-
-    wandb.init()
-except ImportError:
-    wandb = None
-
 torch.manual_seed(0)
 mto.enable_huggingface_checkpointing()
 
@@ -147,9 +138,8 @@ def train():
         model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
         tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
     else:
-        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
+            model_args.model_name_or_path, torch_dtype="auto", device_map="cpu"
         )
         if use_offline_training:
             # When doing offline training, we need to set num_hidden_layers
@@ -231,26 +221,6 @@ def train():
         tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
     )
 
-    class ARValidationCallback(TrainerCallback):
-        def __init__(self, ar_validate_steps: int = 500):
-            self.ar_validate_steps = ar_validate_steps
-
-        def on_step_end(self, args, state, control, **kwargs):
-            if self.ar_validate_steps <= 0:
-                return control
-            if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
-                print_rank_0("Running AR validation...")
-                ars = validate_ar(
-                    model=kwargs["model"],
-                    tokenizer=kwargs["processing_class"],
-                    ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
-                    device=kwargs["model"].device,
-                )
-                print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-                if wandb:
-                    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
-            return control
-
     trainer = Trainer(
         model=model,
         processing_class=tokenizer,
@@ -258,7 +228,6 @@ def on_step_end(self, args, state, control, **kwargs):
         callbacks=[ARValidationCallback(training_args.ar_validate_steps)],
         **data_module,
     )
-    trainer._move_model_to_device(model, trainer.args.device)
 
     # Manually enable this to return loss in eval
     trainer.can_return_loss = True
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 0fc6fb11b..ad0b32074 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -52,7 +52,7 @@
 from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
 from ..medusa.conversion import MedusaDMRegistry
 from ..medusa.medusa_model import MedusaModel
-from ..utils import AcceptanceRateValidation, ResBlock
+from ..utils import AcceptanceRateValidation, ResBlock, temporary_set_config_value
 
 IGNORE_TOKEN_ID = LabelSmoother.ignore_index
 
@@ -445,12 +445,20 @@ def modify(
             param.requires_grad = False
 
         # EAGLE-3 auxiliary hidden_states
-        if self.eagle_config.use_aux_hidden_state:
+        if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
             self._aux_hidden_states = []
             for layer_idx, layer in enumerate(self.model.layers):
                 if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                     layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)
 
+        # delete base model layers for offline training
+        if eagle_offline:
+            self.model._modules.pop("layers")
+
+            # NOTE: this is a temporary hack to bypass hf trainer check:
+            # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
+            self.is_quantized = False
+
         self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
         self._cached_attn_blk_masks = []
@@ -907,13 +915,17 @@ def pseudo_speculative_generate(
                 eagle_input_hidden_states, eagle_position_ids
             )
 
-            _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
-                eagle_input_hidden_states,
-                self.model.embed_tokens(eagle_ids),
-                eagle_attention_mask,
-                eagle_position_ids,
-                position_embeddings,
-            )
+            # Use SDPA attention during generation for both stability and performance
+            with temporary_set_config_value(
+                self.eagle_module.config, "_attn_implementation", "sdpa"
+            ):
+                _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
+                    eagle_input_hidden_states,
+                    self.model.embed_tokens(eagle_ids),
+                    eagle_attention_mask,
+                    eagle_position_ids,
+                    position_embeddings,
+                )
 
             draft_token = eagle_logits[:, -1:, :].argmax(dim=-1)
             if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py
index 0e1c635a7..648cc8163 100644
--- a/modelopt/torch/speculative/utils.py
+++ b/modelopt/torch/speculative/utils.py
@@ -15,6 +15,7 @@
 
 """Utils for speculative decoding."""
 
+import contextlib
 import copy
 import warnings
 from collections import Counter, defaultdict, deque
@@ -362,3 +363,16 @@ def validate(
         ar = (ground_truth.shape[1] - isl) / cnt
 
         return ground_truth, ar
+
+
+@contextlib.contextmanager
+def temporary_set_config_value(config, field, value):
+    """Context manager to temporarily change config value."""
+    if not hasattr(config, field):
+        raise AttributeError(f"Config does not have field '{field}'")
+    original_value = getattr(config, field)
+    try:
+        setattr(config, field, value)
+        yield
+    finally:
+        setattr(config, field, original_value)
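
A quick illustration of the new `temporary_set_config_value` helper used by the SDPA override in `pseudo_speculative_generate`. This is a minimal sketch, assuming the patch is applied and `modelopt` is importable; the `SimpleNamespace` config below is a hypothetical stand-in for `eagle_module.config`:

    from types import SimpleNamespace

    from modelopt.torch.speculative.utils import temporary_set_config_value

    # Hypothetical stand-in for an HF model config; any attribute-bearing object works.
    config = SimpleNamespace(_attn_implementation="flash_attention_2")

    with temporary_set_config_value(config, "_attn_implementation", "sdpa"):
        assert config._attn_implementation == "sdpa"  # overridden inside the block

    assert config._attn_implementation == "flash_attention_2"  # restored on exit

Because the original value is restored in a `finally` clause, it is put back even if the body raises, which is what makes the temporary attention-backend switch during generation safe to mix with training code that expects the configured backend.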