Merged
31 changes: 31 additions & 0 deletions examples/speculative_decoding/eagle_utils.py
@@ -19,11 +19,21 @@

import torch
import transformers
+from ar_validate import validate_ar
+from datasets import load_dataset
from torch.utils.data import Dataset
+from transformers import TrainerCallback
from transformers.trainer_pt_utils import LabelSmoother

from modelopt.torch.utils import print_rank_0

+try:
+    import wandb
+
+    wandb.init()
+except ImportError:
+    wandb = None
Comment on lines +30 to +35

⚠️ Potential issue | 🟠 Major

Premature wandb.init() call in import block.

Calling wandb.init() at module import time (line 33) is problematic because:

  1. It initializes wandb even if ARValidationCallback is never used
  2. It happens before the user can configure wandb settings
  3. It may conflict with other wandb initialization in the application

Apply this diff to defer initialization:

 try:
     import wandb
-
-    wandb.init()
 except ImportError:
     wandb = None

Then in the callback, check if wandb is initialized before logging:

             print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-            if wandb:
+            if wandb and wandb.run:
                 wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-try:
-    import wandb
-
-    wandb.init()
-except ImportError:
-    wandb = None
+# At the top of examples/speculative_decoding/eagle_utils.py, adjust the wandb import block:
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+# … later in the file, inside ARValidationCallback (around the print_rank_0 call):
+def on_evaluate(self, args, state, control, metrics=None, logs=None):
+    # existing metric computation…
+    print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+    if wandb and wandb.run:
+        wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+    # rest of callback…
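For illustration, a minimal standalone sketch of the deferred pattern this suggestion describes (hypothetical helper, not part of this PR): import wandb lazily, skip wandb.init() at import time, and make logging a no-op unless the user has already started a run.

try:
    import wandb
except ImportError:
    wandb = None  # wandb stays optional; log_metric below becomes a no-op


def log_metric(name: str, value: float, step: int) -> None:
    # Log only when wandb is installed and the caller has run wandb.init().
    if wandb is not None and wandb.run is not None:
        wandb.log({name: value}, step=step)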


IGNORE_TOKEN_ID = LabelSmoother.ignore_index

REMOVE_THINK_CHAT_TEMPLATE = (
@@ -382,3 +392,24 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        }

        return batch
+
+
+class ARValidationCallback(TrainerCallback):
+    def __init__(self, ar_validate_steps: int = 1000):
+        self.ar_validate_steps = ar_validate_steps
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.ar_validate_steps <= 0:
+            return control
+        if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
+            print_rank_0("Running AR validation...")
+            ars = validate_ar(
+                model=kwargs["model"],
+                tokenizer=kwargs["processing_class"],
+                ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
+                device=kwargs["model"].device,
+            )
+            print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
+            if wandb:
+                wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+        return control
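Usage sketch for the relocated callback (step values assumed for illustration; main.py below wires it with callbacks=[ARValidationCallback(training_args.ar_validate_steps)]):

validate_every_500 = ARValidationCallback(ar_validate_steps=500)  # AR validation every 500 steps
never_validate = ARValidationCallback(ar_validate_steps=0)        # <= 0 short-circuits on_step_end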
37 changes: 3 additions & 34 deletions examples/speculative_decoding/main.py
@@ -36,24 +36,15 @@

import torch
import transformers
-from ar_validate import validate_ar
-from datasets import load_dataset
-from eagle_utils import make_eagle_supervised_data_module
+from eagle_utils import ARValidationCallback, make_eagle_supervised_data_module
from medusa_utils import make_medusa_supervised_data_module
-from transformers import Trainer, TrainerCallback
+from transformers import Trainer
from transformers.trainer_utils import get_last_checkpoint

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.utils import print_rank_0

-try:
-    import wandb
-
-    wandb.init()
-except ImportError:
-    wandb = None
-
torch.manual_seed(0)
mto.enable_huggingface_checkpointing()

@@ -147,9 +138,8 @@ def train():
        model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    else:
-        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
+            model_args.model_name_or_path, torch_dtype="auto", device_map="cpu"
        )
        if use_offline_training:
            # When doing offline training, we need to set num_hidden_layers
@@ -231,34 +221,13 @@ def train():
        tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
    )

-    class ARValidationCallback(TrainerCallback):
-        def __init__(self, ar_validate_steps: int = 500):
-            self.ar_validate_steps = ar_validate_steps
-
-        def on_step_end(self, args, state, control, **kwargs):
-            if self.ar_validate_steps <= 0:
-                return control
-            if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
-                print_rank_0("Running AR validation...")
-                ars = validate_ar(
-                    model=kwargs["model"],
-                    tokenizer=kwargs["processing_class"],
-                    ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
-                    device=kwargs["model"].device,
-                )
-                print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-                if wandb:
-                    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
-            return control
-
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        callbacks=[ARValidationCallback(training_args.ar_validate_steps)],
        **data_module,
    )
+    trainer._move_model_to_device(model, trainer.args.device)

    # Manually enable this to return loss in eval
    trainer.can_return_loss = True
30 changes: 21 additions & 9 deletions modelopt/torch/speculative/plugins/transformers.py
@@ -52,7 +52,7 @@
from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
from ..medusa.conversion import MedusaDMRegistry
from ..medusa.medusa_model import MedusaModel
-from ..utils import AcceptanceRateValidation, ResBlock
+from ..utils import AcceptanceRateValidation, ResBlock, temporary_set_config_value

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

@@ -445,12 +445,20 @@ def modify(
            param.requires_grad = False

        # EAGLE-3 auxiliary hidden_states
-        if self.eagle_config.use_aux_hidden_state:
+        if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
            self._aux_hidden_states = []
            for layer_idx, layer in enumerate(self.model.layers):
                if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
                    layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)

+        # delete base model layers for offline training
+        if eagle_offline:
+            self.model._modules.pop("layers")
+
+            # NOTE: this is a temporary hack to bypass hf trainer check:
+            # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
+            self.is_quantized = False
+
        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
        self._cached_attn_blk_masks = []
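To illustrate the _modules.pop("layers") call above, a toy sketch (generic torch.nn code, not ModelOpt's): popping a submodule detaches it from the module tree, so its weights disappear from parameters() and state_dict(); this is how offline training sheds the base model's decoder layers while keeping the rest of the model.

import torch.nn as nn

toy = nn.Module()
toy.embed = nn.Embedding(10, 4)
toy.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

n_before = sum(p.numel() for p in toy.parameters())
toy._modules.pop("layers")  # same mechanism as the diff above
n_after = sum(p.numel() for p in toy.parameters())
assert n_after < n_before  # the popped layers no longer contribute parameters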

@@ -907,13 +915,17 @@ def pseudo_speculative_generate(
                eagle_input_hidden_states, eagle_position_ids
            )

-            _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
-                eagle_input_hidden_states,
-                self.model.embed_tokens(eagle_ids),
-                eagle_attention_mask,
-                eagle_position_ids,
-                position_embeddings,
-            )
+            # Use SDPA attention during generation for both stability and performance
+            with temporary_set_config_value(
+                self.eagle_module.config, "_attn_implementation", "sdpa"
+            ):
+                _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
+                    eagle_input_hidden_states,
+                    self.model.embed_tokens(eagle_ids),
+                    eagle_attention_mask,
+                    eagle_position_ids,
+                    position_embeddings,
+                )

            draft_token = eagle_logits[:, -1:, :].argmax(dim=-1)
            if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
14 changes: 14 additions & 0 deletions modelopt/torch/speculative/utils.py
@@ -15,6 +15,7 @@

"""Utils for speculative decoding."""

+import contextlib
import copy
import warnings
from collections import Counter, defaultdict, deque
@@ -362,3 +363,16 @@ def validate(
        ar = (ground_truth.shape[1] - isl) / cnt

        return ground_truth, ar
+
+
+@contextlib.contextmanager
+def temporary_set_config_value(config, field, value):
+    """Context manager to temporarily change config value."""
+    if not hasattr(config, field):
+        raise AttributeError(f"Config does not have field '{field}'")
+    original_value = getattr(config, field)
+    try:
+        setattr(config, field, value)
+        yield
+    finally:
+        setattr(config, field, original_value)
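A quick usage sketch of the new context manager (toy config object for illustration; the real caller in plugins/transformers.py swaps _attn_implementation on the EAGLE module's config):

from types import SimpleNamespace

config = SimpleNamespace(_attn_implementation="flash_attention_2")

with temporary_set_config_value(config, "_attn_implementation", "sdpa"):
    assert config._attn_implementation == "sdpa"  # temporarily overridden

assert config._attn_implementation == "flash_attention_2"  # restored on exit, even on error
# Asking for a missing attribute raises AttributeError up front instead of silently creating it.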