7 changes: 4 additions & 3 deletions docs/source/guides/5_speculative_decoding.rst
@@ -39,6 +39,7 @@ Example usage:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import default_eagle_config, eagle3_default_config

# User-defined model
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
@@ -53,8 +54,9 @@ Example usage:
}
elif mode == "eagle":
config = {
"eagle_num_layers": 1
}
"eagle_architecture_config": eagle3_default_config,
# use default_eagle_config for eagle1
}
mtsp.convert(model, [(mode, config)])


@@ -72,7 +74,6 @@ After converting to a speculative decoding model, you need to fine-tune the deco
mto.enable_huggingface_checkpointing()

trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module)
trainer._move_model_to_device(model, trainer.args.device)

trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_state()
21 changes: 15 additions & 6 deletions examples/speculative_decoding/README.md
@@ -244,23 +244,32 @@ To add a system prompt, use the `--system_prompt <system_prompt_text>` argument.

For large-scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support.

### Configuring Draft Model

For EAGLE‑1 and EAGLE‑3, ModelOpt provides a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37). You can override the default settings by providing an additional JSON dict, e.g. to use a different MLP intermediate size for the draft model:

```json
{
"intermediate_size": 28672
}
```
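
Under the hood, this dict is merged on top of the ModelOpt default architecture config before conversion. A minimal sketch of that flow, assuming a Hugging Face base model and an override file named `eagle_config.json` (the merge itself is normally handled by the training entrypoint):

```python
import json
from copy import deepcopy

from transformers import AutoModelForCausalLM

import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import eagle3_default_config

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Start from the ModelOpt default EAGLE-3 draft architecture ...
arch_cfg = deepcopy(eagle3_default_config)

# ... and overlay the user-provided overrides, e.g. intermediate_size.
with open("eagle_config.json") as f:
    arch_cfg.update(json.load(f))

mtsp.convert(model, [("eagle", {"eagle_architecture_config": arch_cfg})])
```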

### Draft Vocabulary Compression

We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
For EAGLE‑1 and EAGLE‑3, we can optionally use a smaller vocabulary for the draft model to speed up training and inference. For example, Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by selecting the most frequently occurring tokens in our training set:

```bash
python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache
```

This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.

### Configuring Draft Model
This will produce a `d2t.pt` file representing the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`.

For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
To train with a compressed draft vocab, set the following two fields in `eagle_config.json`:

```json
{
"draft_vocab_size": 32000
"draft_vocab_size": 32000,
"draft_vocab_cache_path": "<path_to_d2t.pt>"
}
```
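
At inference time, `d2t` acts as an offset table from the compressed draft vocabulary into the target model's full vocabulary. A minimal sketch of the lookup described above (the cache path is illustrative):

```python
import torch

# Offset table produced by calibrate_draft_vocab.py; the path is illustrative.
d2t = torch.load("draft_vocab_cache/d2t.pt")  # shape: [draft_vocab_size]

# Token ids sampled by the draft model in its compressed vocab.
draft_tokens = torch.tensor([17, 4092, 31999])

# Map each draft token id back into the target model's vocabulary.
target_tokens = draft_tokens + d2t[draft_tokens]
```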

53 changes: 11 additions & 42 deletions examples/speculative_decoding/main.py
@@ -31,6 +31,7 @@

import json
import os
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Literal

@@ -42,6 +43,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import default_eagle_config, eagle3_default_config
from modelopt.torch.utils import print_rank_0

torch.manual_seed(0)
@@ -70,10 +72,6 @@ class DataArguments:
},
)
lazy_preprocess: bool = True
draft_vocab_cache_dir: str = field(
default="draft_vocab_cache",
metadata={"help": "Path to the d2t cache directory."},
)
vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."})
vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."})

@@ -176,53 +174,24 @@ def train():
}
mtsp.convert(model, [("medusa", config)])
elif training_args.mode in ["eagle1", "eagle3"]:
from modelopt.torch.speculative.config import EAGLE1_DEFAULT_CFG, EAGLE3_DEFAULT_CFG

# Load default config
config = {
"eagle1": EAGLE1_DEFAULT_CFG,
"eagle3": EAGLE3_DEFAULT_CFG,
}[training_args.mode]["config"]
default_eagle_arch_cfg = {
"eagle1": deepcopy(default_eagle_config),
"eagle3": deepcopy(eagle3_default_config),
}[training_args.mode]

# overwrite config with custom config
if use_offline_training:
config["eagle_offline"] = True
config = {
"eagle_offline": use_offline_training,
"eagle_architecture_config": default_eagle_arch_cfg,
}

# Overwrite default config with custom config
if eagle_args.eagle_config:
with open(eagle_args.eagle_config) as f:
custom_config = json.load(f)
config["eagle_architecture_config"].update(custom_config)

# Hidden size and vocab size must match base model
llm_config = (
model.config.llm_config if hasattr(model.config, "llm_config") else model.config
)
config["eagle_architecture_config"].update(
{
"hidden_size": llm_config.hidden_size,
"vocab_size": llm_config.vocab_size,
# we also overwrite max_pos_embedding for deployment compatibility
"max_position_embeddings": llm_config.max_position_embeddings,
"draft_vocab_size": custom_config["draft_vocab_size"]
if eagle_args.eagle_config and "draft_vocab_size" in custom_config
else llm_config.vocab_size,
}
)

mtsp.convert(model, [("eagle", config)])

# read draft vocab cache
if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
try:
model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path))
vocab_cache_path = os.path.join(
data_args.draft_vocab_cache_dir, model_name, "d2t.pt"
)
vocab_cache = torch.load(vocab_cache_path)
model.eagle_module.d2t = vocab_cache
print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.")
except Exception as e:
raise e
else:
raise Exception(f"{training_args.mode} is not supported!")

43 changes: 34 additions & 9 deletions modelopt/torch/speculative/plugins/transformers.py
@@ -47,6 +47,7 @@
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import ModelOutput

from ...utils import print_rank_0
from ..eagle.conversion import EagleDMRegistry
from ..eagle.eagle_model import EagleModel
from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
@@ -429,6 +430,25 @@ def _get_eagle_device(self):
base_model_last_layer = self._base_model.layers[-1]
return next(base_model_last_layer.parameters()).device

def _rewrite_eagle_cfg(self, eagle_arch_cfg):
"""Overwrite necessary fields in eagle config to match target."""
# hidden size, vocab size, max rope must match target
eagle_arch_cfg.update(
{
"hidden_size": self._base_llm_config.hidden_size,
"vocab_size": self._base_llm_config.vocab_size,
"max_position_embeddings": self._base_llm_config.max_position_embeddings,
"draft_vocab_size": eagle_arch_cfg.get(
"draft_vocab_size", self._base_llm_config.vocab_size
),
}
)

if "_attn_implementation" not in eagle_arch_cfg:
eagle_arch_cfg["_attn_implementation"] = "sdpa"

return eagle_arch_cfg

def modify(
self,
eagle_offline,
@@ -445,6 +465,9 @@ def modify(
Args:
config: The config for eagle decoder layers.
"""
# Overwrite eagle config to match target model
eagle_architecture_config = self._rewrite_eagle_cfg(eagle_architecture_config)

super().modify(
eagle_offline=eagle_offline,
eagle_hidden_state_distillation=eagle_hidden_state_distillation,
@@ -456,8 +479,7 @@ def modify(
eagle_architecture_config=eagle_architecture_config,
)
self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
if self.eagle_config._attn_implementation is None:
self.eagle_config._attn_implementation = "sdpa"

decoder_cls = (
type(self.model.layers[-1]) if self.eagle_reuse_base_decoder else LlamaDecoderLayer
)
@@ -470,13 +492,6 @@
):
self._set_default_aux_hidden_state_layers()

if self._base_llm_config.hidden_size != self.eagle_config.hidden_size:
raise ValueError(
"EAGLE module hidden size "
f"{self.eagle_config.hidden_size} must match base model hidden size "
f"{self._base_llm_config.hidden_size}!"
)

self.eagle_module = EagleModule(
self.eagle_config,
decoder_cls,
@@ -511,6 +526,16 @@ def modify(
self.num_ttt_steps = 4 # NOTE: (hg) hardcoded for now. Might add to config later.
self._cached_attn_blk_masks = {}

# load draft vocab cache
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
try:
self.eagle_module.d2t = torch.load(self.eagle_config.draft_vocab_cache_path)
print_rank_0(
f"Loaded draft vocab cache from {self.eagle_config.draft_vocab_cache_path}."
)
except Exception as e:
raise ValueError(f"Failed to load draft vocab cache: {e}") from e

def _get_ttt_attention_mask(self, seq_length, ttt_step):
# compile and cached flex attention masks in first call
if ttt_step not in self._cached_attn_blk_masks: