Skip to content

Commit c55bcf0

Browse files
authored
fix: eagle3 quantized base model (#383)
Signed-off-by: h-guo18 <[email protected]>
1 parent c9db0ce commit c55bcf0

File tree

4 files changed

+69
-43
lines changed

4 files changed

+69
-43
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,21 @@
1919

2020
import torch
2121
import transformers
22+
from ar_validate import validate_ar
23+
from datasets import load_dataset
2224
from torch.utils.data import Dataset
25+
from transformers import TrainerCallback
2326
from transformers.trainer_pt_utils import LabelSmoother
2427

2528
from modelopt.torch.utils import print_rank_0
2629

30+
try:
31+
import wandb
32+
33+
wandb.init()
34+
except ImportError:
35+
wandb = None
36+
2737
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
2838

2939
REMOVE_THINK_CHAT_TEMPLATE = (
@@ -382,3 +392,24 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
382392
}
383393

384394
return batch
395+
396+
397+
class ARValidationCallback(TrainerCallback):
    """Trainer callback that periodically measures the draft model's acceptance rate (AR).

    Every ``ar_validate_steps`` optimizer steps, runs ``validate_ar`` over the
    MT-Bench prompt set and reports the mean AR via ``print_rank_0`` and, when
    the ``wandb`` module was importable, to Weights & Biases.
    """

    def __init__(self, ar_validate_steps: int = 1000):
        # A value <= 0 disables AR validation entirely (checked in on_step_end).
        self.ar_validate_steps = ar_validate_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Run AR validation on schedule.

        Args:
            args: TrainingArguments (unused here, required by the callback API).
            state: TrainerState; ``state.global_step`` drives the schedule.
            control: TrainerControl; always returned unchanged.
            **kwargs: Trainer-provided extras; ``model`` and ``processing_class``
                are read here.

        Returns:
            The ``control`` object, unchanged.
        """
        if self.ar_validate_steps <= 0:
            return control
        # Skip step 0 so validation only runs after real training progress.
        if state.global_step > 0 and state.global_step % self.ar_validate_steps == 0:
            print_rank_0("Running AR validation...")
            ars = validate_ar(
                model=kwargs["model"],
                tokenizer=kwargs["processing_class"],
                ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
                device=kwargs["model"].device,
            )
            # Compute the mean once (the original recomputed it for each sink)
            # and guard against an empty result, which would otherwise raise
            # ZeroDivisionError and kill the training run.
            if ars:
                mean_ar = sum(ars) / len(ars)
                print_rank_0(f"Step {state.global_step} AR: {mean_ar:.4f}")
                if wandb:
                    wandb.log({"validate_ar": mean_ar}, step=state.global_step)
            else:
                print_rank_0(f"Step {state.global_step} AR: validation returned no samples")
        return control

examples/speculative_decoding/main.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,15 @@
3636

3737
import torch
3838
import transformers
39-
from ar_validate import validate_ar
40-
from datasets import load_dataset
41-
from eagle_utils import make_eagle_supervised_data_module
39+
from eagle_utils import ARValidationCallback, make_eagle_supervised_data_module
4240
from medusa_utils import make_medusa_supervised_data_module
43-
from transformers import Trainer, TrainerCallback
41+
from transformers import Trainer
4442
from transformers.trainer_utils import get_last_checkpoint
4543

4644
import modelopt.torch.opt as mto
4745
import modelopt.torch.speculative as mtsp
4846
from modelopt.torch.utils import print_rank_0
4947

50-
try:
51-
import wandb
52-
53-
wandb.init()
54-
except ImportError:
55-
wandb = None
56-
5748
torch.manual_seed(0)
5849
mto.enable_huggingface_checkpointing()
5950

@@ -147,9 +138,8 @@ def train():
147138
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
148139
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
149140
else:
150-
model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
151141
model = transformers.AutoModelForCausalLM.from_pretrained(
152-
model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
142+
model_args.model_name_or_path, torch_dtype="auto", device_map="cpu"
153143
)
154144
if use_offline_training:
155145
# When doing offline training, we need to set num_hidden_layers
@@ -231,34 +221,13 @@ def train():
231221
tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
232222
)
233223

234-
class ARValidationCallback(TrainerCallback):
235-
def __init__(self, ar_validate_steps: int = 500):
236-
self.ar_validate_steps = ar_validate_steps
237-
238-
def on_step_end(self, args, state, control, **kwargs):
239-
if self.ar_validate_steps <= 0:
240-
return control
241-
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
242-
print_rank_0("Running AR validation...")
243-
ars = validate_ar(
244-
model=kwargs["model"],
245-
tokenizer=kwargs["processing_class"],
246-
ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
247-
device=kwargs["model"].device,
248-
)
249-
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
250-
if wandb:
251-
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
252-
return control
253-
254224
trainer = Trainer(
255225
model=model,
256226
processing_class=tokenizer,
257227
args=training_args,
258228
callbacks=[ARValidationCallback(training_args.ar_validate_steps)],
259229
**data_module,
260230
)
261-
trainer._move_model_to_device(model, trainer.args.device)
262231

263232
# Manually enable this to return loss in eval
264233
trainer.can_return_loss = True

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
5353
from ..medusa.conversion import MedusaDMRegistry
5454
from ..medusa.medusa_model import MedusaModel
55-
from ..utils import AcceptanceRateValidation, ResBlock
55+
from ..utils import AcceptanceRateValidation, ResBlock, temporary_set_config_value
5656

5757
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
5858

@@ -445,12 +445,20 @@ def modify(
445445
param.requires_grad = False
446446

447447
# EAGLE-3 auxiliary hidden_states
448-
if self.eagle_config.use_aux_hidden_state:
448+
if (not eagle_offline) and self.eagle_config.use_aux_hidden_state:
449449
self._aux_hidden_states = []
450450
for layer_idx, layer in enumerate(self.model.layers):
451451
if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids:
452452
layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook)
453453

454+
# delete base model layers for offline training
455+
if eagle_offline:
456+
self.model._modules.pop("layers")
457+
458+
# NOTE: this is a temporary hack to bypass hf trainer check:
459+
# https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
460+
self.is_quantized = False
461+
454462
self.num_ttt_steps = 3 # NOTE: (hg) hardcoded for now. Might add to config later.
455463
self._cached_attn_blk_masks = []
456464

@@ -907,13 +915,17 @@ def pseudo_speculative_generate(
907915
eagle_input_hidden_states, eagle_position_ids
908916
)
909917

910-
_, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
911-
eagle_input_hidden_states,
912-
self.model.embed_tokens(eagle_ids),
913-
eagle_attention_mask,
914-
eagle_position_ids,
915-
position_embeddings,
916-
)
918+
# Use SDPA attention during generation for both stability and performance
919+
with temporary_set_config_value(
920+
self.eagle_module.config, "_attn_implementation", "sdpa"
921+
):
922+
_, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
923+
eagle_input_hidden_states,
924+
self.model.embed_tokens(eagle_ids),
925+
eagle_attention_mask,
926+
eagle_position_ids,
927+
position_embeddings,
928+
)
917929

918930
draft_token = eagle_logits[:, -1:, :].argmax(dim=-1)
919931
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:

modelopt/torch/speculative/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Utils for speculative decoding."""
1717

18+
import contextlib
1819
import copy
1920
import warnings
2021
from collections import Counter, defaultdict, deque
@@ -362,3 +363,16 @@ def validate(
362363
ar = (ground_truth.shape[1] - isl) / cnt
363364

364365
return ground_truth, ar
366+
367+
368+
@contextlib.contextmanager
def temporary_set_config_value(config, field, value):
    """Context manager to temporarily change config value.

    On entry, ``config.<field>`` is set to ``value``; on exit (normal or via
    exception) the original value is restored.

    Raises:
        AttributeError: if ``config`` has no attribute named ``field``.
    """
    if not hasattr(config, field):
        raise AttributeError(f"Config does not have field '{field}'")
    saved = getattr(config, field)
    try:
        setattr(config, field, value)
        yield
    finally:
        # Restore even if the managed block raised.
        setattr(config, field, saved)

0 commit comments

Comments
 (0)