@@ -525,30 +525,33 @@ def compute_loss(self, *args, **kwargs):
class EagleTrainingPlot(TrainerCallback):
    """Callback that logs training accuracy and, optionally, an estimated AR."""

    def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
        """Store the callback settings and start a wandb run on the master rank.

        Args:
            ar_validate_steps: Stored on the instance unchanged; ``on_log``
                does not read it.
            estimate_ar: When true, ``on_log`` additionally computes, prints
                and logs an estimated training AR.
        """
        self.ar_validate_steps = ar_validate_steps
        self.estimate_ar = estimate_ar
        if wandb and is_master():
            wandb.init()

    @staticmethod
    def _estimate_ar(average_acc) -> float:
        """Return ``1 + sum_k prod_{j<=k} average_acc[j]``.

        NOTE: This is only an estimate of the real AR.
        """
        est_ar = 1
        acc_cumprod = 1
        for step_acc in average_acc:
            acc_cumprod *= step_acc
            est_ar += acc_cumprod
        return est_ar

    def on_log(self, args, state, control, **kwargs):
        """Log the mean training accuracies (and estimated AR) at a log step.

        Averages ``state.training_accs`` over axis 0, optionally prints the
        estimated AR, sends the per-index accuracies to wandb on the master
        rank, and finally resets ``state.training_accs`` to an empty list.

        Returns ``control`` unchanged when there is nothing to report.
        """
        training_accs = getattr(state, "training_accs", None)
        # This method resets the list to [] below, so a second log event before
        # new accuracies arrive would otherwise average an empty list:
        # np.mean([]) is a NaN scalar (with a RuntimeWarning) and iterating
        # over it raises TypeError.
        if training_accs is None or len(training_accs) == 0:
            return control

        average_acc = np.mean(training_accs, axis=0)

        est_ar = None
        if self.estimate_ar:
            # Calculate mean training AR since last log
            est_ar = self._estimate_ar(average_acc)
            print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")

        # log to wandb
        if wandb and is_master():
            metrics = {
                f"step_{i}_train_acc": step_acc for i, step_acc in enumerate(average_acc)
            }
            if self.estimate_ar:
                metrics["estimated_training_ar"] = est_ar
            # One call per step: all keys land in the same history row.
            wandb.log(metrics, step=state.global_step)

        # reset training_accs
        state.training_accs = []
0 commit comments