diff --git a/docs/source/guides/5_speculative_decoding.rst b/docs/source/guides/5_speculative_decoding.rst
index 001808e35..4994079fc 100644
--- a/docs/source/guides/5_speculative_decoding.rst
+++ b/docs/source/guides/5_speculative_decoding.rst
@@ -39,6 +39,7 @@ Example usage:
     import torch
     from transformers import AutoModelForCausalLM, AutoTokenizer
     import modelopt.torch.speculative as mtsp
+    from modelopt.torch.speculative.config import default_eagle_config, eagle3_default_config
 
     # User-defined model
     model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
@@ -53,8 +54,9 @@ Example usage:
        }
    elif mode == "eagle":
        config = {
-           "eagle_num_layers": 1
-       }
+           "eagle_architecture_config": eagle3_default_config,
+           # use default_eagle_config for EAGLE-1
+       }
 
    mtsp.convert(model, [(mode, config)])
 
@@ -72,7 +74,6 @@ After converting to a speculative decoding model, you need to fine-tune the decoder:
    mto.enable_huggingface_checkpointing()
 
    trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module)
-   trainer._move_model_to_device(model, trainer.args.device)
    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_state()
 
diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 0555faedc..59f9617fa 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -244,23 +244,32 @@
 To add a system prompt, use the `--system_prompt ` argument.
 
 For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support.
 
+### Configuring Draft Model
+
+For EAGLE-1 and EAGLE-3, we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override the default settings by providing an additional JSON dict, e.g. using a different MLP intermediate size for the draft model:
+
+```json
+{
+    "intermediate_size": 28672
+}
+```
+
 ### Draft Vocabulary Compression
 
-We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
+For EAGLE-1 and EAGLE-3, we can optionally use a smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequently occurring tokens in our training set:
 
 ```bash
 python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
 ```
 
-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
-
-### Configuring Draft Model
+This will produce a `d2t.pt` file representing the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.
 
-For EAGLE-1 and EAGLE-3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
+To train with a compressed draft vocab, set the following two fields in `eagle_config.json`:
 
 ```json
 {
-    "draft_vocab_size": 32000
+    "draft_vocab_size": 32000,
+    "draft_vocab_cache_path": ""
 }
 ```
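The README text above defines the draft-to-target mapping only through the formula `target_token = draft_token + d2t[draft_token]`. As a reading aid (not part of this diff), here is a minimal sketch of how such an offset-style `d2t` table can be built from token frequencies and then applied; the frequency counts are random placeholders standing in for real training-set statistics:

```python
import torch

# Placeholder token-frequency counts over a training set (real counts would
# come from tokenizing the data, as calibrate_draft_vocab.py does).
vocab_size, draft_vocab_size = 128256, 32000
token_counts = torch.randint(0, 1000, (vocab_size,))

# Keep the most frequent target tokens as the draft vocabulary,
# in ascending target-token-id order.
top_ids = torch.topk(token_counts, k=draft_vocab_size).indices.sort().values

# d2t[i] is the offset that maps draft token id i back to its target token id.
d2t = top_ids - torch.arange(draft_vocab_size)

# At inference time a draft token is mapped back to the target vocabulary by:
draft_token = torch.tensor(5)
target_token = draft_token + d2t[draft_token]
assert target_token == top_ids[5]
```

Storing offsets instead of absolute ids keeps the mapping a single addition at decode time, consistent with the formula quoted in the README.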
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index 35c9b8ede..2c75186a5 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -31,6 +31,7 @@
 
 import json
 import os
+from copy import deepcopy
 from dataclasses import dataclass, field
 from typing import Literal
 
@@ -42,6 +43,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
+from modelopt.torch.speculative.config import default_eagle_config, eagle3_default_config
 from modelopt.torch.utils import print_rank_0
 
 torch.manual_seed(0)
 
@@ -70,10 +72,6 @@ class DataArguments:
         },
     )
     lazy_preprocess: bool = True
-    draft_vocab_cache_dir: str = field(
-        default="draft_vocab_cache",
-        metadata={"help": "Path to the d2t cache directory."},
-    )
     vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
     vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})
 
@@ -176,53 +174,24 @@ def train():
         }
         mtsp.convert(model, [("medusa", config)])
     elif training_args.mode in ["eagle1", "eagle3"]:
-        from modelopt.torch.speculative.config import EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG
-
         # Load default config
-        config = {
-            "eagle1": EAGLE1_DEFAULT_CFG,
-            "eagle3": EAGLE3_DEFAULT_CFG,
-        }[training_args.mode]["config"]
+        default_eagle_arch_cfg = {
+            "eagle1": deepcopy(default_eagle_config),
+            "eagle3": deepcopy(eagle3_default_config),
+        }[training_args.mode]
 
-        # overwrite config with custom config
-        if use_offline_training:
-            config["eagle_offline"] = True
+        config = {
+            "eagle_offline": use_offline_training,
+            "eagle_architecture_config": default_eagle_arch_cfg,
+        }
 
+        # Overwrite default config with custom config
        if eagle_args.eagle_config:
            with open(eagle_args.eagle_config) as f:
                custom_config = json.load(f)
            config["eagle_architecture_config"].update(custom_config)
 
-        # Hidden size and vocab size must match base model
-        llm_config = (
-            model.config.llm_config if hasattr(model.config, "llm_config") else model.config
-        )
-        config["eagle_architecture_config"].update(
-            {
-                "hidden_size": llm_config.hidden_size,
-                "vocab_size": llm_config.vocab_size,
-                # we also overwrite max_pos_embedding for deployment compatibility
-                "max_position_embeddings": llm_config.max_position_embeddings,
-                "draft_vocab_size": custom_config["draft_vocab_size"]
-                if eagle_args.eagle_config and "draft_vocab_size" in custom_config
-                else llm_config.vocab_size,
-            }
-        )
-
        mtsp.convert(model, [("eagle", config)])
-
-        # read draft vocab cache
-        if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
-            try:
-                model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path))
-                vocab_cache_path = os.path.join(
-                    data_args.draft_vocab_cache_dir, model_name, "d2t.pt"
-                )
-                vocab_cache = torch.load(vocab_cache_path)
-                model.eagle_module.d2t = vocab_cache
-                print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.")
-            except Exception as e:
-                raise e
     else:
         raise Exception(f"{training_args.mode} is not supported!")
 
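With the `main.py` changes above, the training script no longer stitches target-model fields into the draft config itself. Below is a minimal sketch of the resulting user-facing flow, assuming the `default_eagle_config` / `eagle3_default_config` dicts and the `mtsp.convert` API referenced in this diff; the `intermediate_size` override mirrors the README example:

```python
from copy import deepcopy

from transformers import AutoModelForCausalLM

import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import default_eagle_config, eagle3_default_config

# Target (base) model to augment with an EAGLE draft module.
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Start from the ModelOpt default draft architecture (EAGLE-3 here;
# use default_eagle_config for EAGLE-1) and apply user overrides on top.
arch_cfg = deepcopy(eagle3_default_config)
arch_cfg.update({"intermediate_size": 28672})  # example override from the README

config = {
    "eagle_offline": False,  # set True for offline training (use_offline_training in main.py)
    "eagle_architecture_config": arch_cfg,
}
mtsp.convert(model, [("eagle", config)])
```

Hidden size, vocab size, and max position embeddings no longer need to be set here; after this change they are copied from the target model inside the plugin (see the next file).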
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 1aed13e87..b4860230b 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -47,6 +47,7 @@
 from transformers.trainer_pt_utils import LabelSmoother
 from transformers.utils import ModelOutput
 
+from ...utils import print_rank_0
 from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
 from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
@@ -429,6 +430,25 @@ def _get_eagle_device(self):
         base_model_last_layer = self._base_model.layers[-1]
         return next(base_model_last_layer.parameters()).device
 
+    def _rewrite_eagle_cfg(self, eagle_arch_cfg):
+        """Overwrite fields in the eagle config that must match the target model."""
+        # Hidden size, vocab size, and max position embeddings must match the target
+        eagle_arch_cfg.update(
+            {
+                "hidden_size": self._base_llm_config.hidden_size,
+                "vocab_size": self._base_llm_config.vocab_size,
+                "max_position_embeddings": self._base_llm_config.max_position_embeddings,
+                "draft_vocab_size": eagle_arch_cfg.get(
+                    "draft_vocab_size", self._base_llm_config.vocab_size
+                ),
+            }
+        )
+
+        if "_attn_implementation" not in eagle_arch_cfg:
+            eagle_arch_cfg["_attn_implementation"] = "sdpa"
+
+        return eagle_arch_cfg
+
     def modify(
         self,
         eagle_offline,
@@ -445,6 +465,9 @@
         Args:
             config: The config for eagle decoder layers.
         """
+        # Overwrite eagle config to match the target model
+        eagle_architecture_config = self._rewrite_eagle_cfg(eagle_architecture_config)
+
         super().modify(
             eagle_offline=eagle_offline,
             eagle_hidden_state_distillation=eagle_hidden_state_distillation,
@@ -456,8 +479,7 @@
             eagle_architecture_config=eagle_architecture_config,
         )
         self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
-        if self.eagle_config._attn_implementation is None:
-            self.eagle_config._attn_implementation = "sdpa"
+
         decoder_cls = (
             type(self.model.layers[-1]) if self.eagle_reuse_base_decoder else LlamaDecoderLayer
         )
@@ -470,13 +492,6 @@
         ):
             self._set_default_aux_hidden_state_layers()
 
-        if self._base_llm_config.hidden_size != self.eagle_config.hidden_size:
-            raise ValueError(
-                "EAGLE module hidden size "
-                f"{self.eagle_config.hidden_size} must match base model hidden size "
-                f"{self._base_llm_config.hidden_size}!"
-            )
-
         self.eagle_module = EagleModule(
             self.eagle_config,
             decoder_cls,
@@ -511,6 +526,16 @@
         self.num_ttt_steps = 4  # NOTE: (hg) hardcoded for now. Might add to config later.
         self._cached_attn_blk_masks = {}
 
+        # Load the draft vocab cache (d2t mapping) when a reduced draft vocab is used
+        if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+            try:
+                self.eagle_module.d2t = torch.load(self.eagle_config.draft_vocab_cache_path)
+                print_rank_0(
+                    f"Loaded draft vocab cache from {self.eagle_config.draft_vocab_cache_path}."
+                )
+            except Exception as e:
+                raise ValueError(f"Failed to load draft vocab cache: {e}") from e
+
     def _get_ttt_attention_mask(self, seq_length, ttt_step):
         # compile and cached flex attention masks in first call
         if ttt_step not in self._cached_attn_blk_masks:
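To make the precedence rules introduced by `_rewrite_eagle_cfg` concrete, here is a small standalone sketch (plain Python, not part of the PR) that mimics the same logic with a stand-in for `self._base_llm_config`:

```python
from types import SimpleNamespace

# Stand-in for self._base_llm_config of the target model (values are placeholders).
base_llm_config = SimpleNamespace(
    hidden_size=2048, vocab_size=128256, max_position_embeddings=131072
)


def rewrite_eagle_cfg(eagle_arch_cfg: dict) -> dict:
    """Mirror of _rewrite_eagle_cfg: target-model fields always win."""
    eagle_arch_cfg.update(
        {
            "hidden_size": base_llm_config.hidden_size,
            "vocab_size": base_llm_config.vocab_size,
            "max_position_embeddings": base_llm_config.max_position_embeddings,
            # draft_vocab_size falls back to the full target vocab size.
            "draft_vocab_size": eagle_arch_cfg.get(
                "draft_vocab_size", base_llm_config.vocab_size
            ),
        }
    )
    # Default the attention implementation when the user did not choose one.
    eagle_arch_cfg.setdefault("_attn_implementation", "sdpa")
    return eagle_arch_cfg


cfg = rewrite_eagle_cfg({"hidden_size": 4096, "draft_vocab_size": 32000})
assert cfg["hidden_size"] == 2048  # target value overrides the user-supplied one
assert cfg["draft_vocab_size"] == 32000  # user-supplied value is kept
assert cfg["_attn_implementation"] == "sdpa"
```

When `draft_vocab_size` differs from `vocab_size`, the plugin additionally expects `draft_vocab_cache_path` to point at the `d2t.pt` produced by `calibrate_draft_vocab.py`, and raises a `ValueError` if the cache cannot be loaded.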