Skip to content

Commit d4d9e50

Browse files
committed
address review comments
Signed-off-by: h-guo18 <[email protected]>
1 parent 0e98540 commit d4d9e50

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,30 +525,33 @@ def compute_loss(self, *args, **kwargs):
525525
class EagleTrainingPlot(TrainerCallback):
526526
"""Callback that plots training acc and AR during training."""
527527

528-
def __init__(self, ar_validate_steps: int = 1000):
528+
def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
529529
self.ar_validate_steps = ar_validate_steps
530530
if wandb and is_master():
531531
wandb.init()
532+
self.estimate_ar = estimate_ar
532533

533534
def on_log(self, args, state, control, **kwargs):
534535
"""Log training acc and estimate AR during log step."""
535536
if not hasattr(state, "training_accs"):
536537
return control
537-
# Calculate mean training AR since last log
538-
# NOTE: This is only an estimate of the real AR.
539538
average_acc = np.mean(state.training_accs, axis=0)
540-
est_ar = 1
541-
acc_cumprod = 1
542-
for step_acc in average_acc:
543-
est_ar += acc_cumprod * step_acc
544-
acc_cumprod *= step_acc
545-
print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
539+
if self.estimate_ar:
540+
# Calculate mean training AR since last log
541+
# NOTE: This is only an estimate of the real AR.
542+
est_ar = 1
543+
acc_cumprod = 1
544+
for step_acc in average_acc:
545+
est_ar += acc_cumprod * step_acc
546+
acc_cumprod *= step_acc
547+
print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")
546548

547549
# log to wandb
548550
if wandb and is_master():
549551
for i, step_acc in enumerate(average_acc):
550552
wandb.log({f"step_{i}_train_acc": step_acc}, step=state.global_step)
551-
wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
553+
if self.estimate_ar:
554+
wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)
552555

553556
# reset training_accs
554557
state.training_accs = []

examples/speculative_decoding/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def train():
143143
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
144144
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
145145
else:
146+
# To avoid OOM for large models, we load and convert the model on CPU first.
147+
# The model will be moved to GPU during HF Trainer initialization.
146148
model = transformers.AutoModelForCausalLM.from_pretrained(
147149
model_args.model_name_or_path,
148150
torch_dtype="auto",

0 commit comments

Comments
 (0)