@@ -113,16 +113,13 @@ class MetricsBuffer:
113113 step: The training step number.
114114 losses: A list of loss values recorded within this step (e.g., across
115115 gradient accumulation steps).
116- step_time_deltas: A list of time deltas for each computation within this
117- step.
118116 additional_metrics: Dictionary for storing additional metrics. The key is
119117 the metric name, and the value is a tuple containing a list of metric
120118 values and a callable to aggregate them.
121119 """
122120
123121 step : int
124122 losses : List [ArrayLike ]
125- step_time_deltas : List [float ]
126123 additional_metrics : Dict [
127124 str , Tuple [List [ArrayLike ], Callable [[ArrayLike ], ArrayLike ]]
128125 ] = dataclasses .field (default_factory = dict )
@@ -132,11 +129,6 @@ def loss(self):
132129 """Returns the mean of the recorded losses for the step."""
133130 return np .mean (np .array ([np .array (x ) for x in self .losses ]))
134131
135- @property
136- def step_time_delta (self ):
137- """Returns the mean of the recorded step time deltas for the step."""
138- return np .mean (self .step_time_deltas )
139-
140132
141133def _calculate_global_batch_size (train_example : Any ) -> int :
142134 """Calculates the global batch size from a training example.
@@ -460,7 +452,6 @@ def _log_metrics(
460452 self ,
461453 loss : ArrayLike ,
462454 step : int | None = None ,
463- step_time_delta : float | None = None ,
464455 additional_metrics : dict [str , ArrayLike ] | None = None ,
465456 ):
466457 """Logs the metrics to the metrics logger and console."""
@@ -478,21 +469,6 @@ def _log_metrics(
478469 self ._mode ,
479470 step ,
480471 )
481- if step_time_delta is not None :
482- self .metrics_logger .log (
483- self .metrics_prefix ,
484- "step_time_sec" ,
485- step_time_delta ,
486- self ._mode ,
487- step ,
488- )
489- self .metrics_logger .log (
490- self .metrics_prefix ,
491- "steps_per_sec" ,
492- 1.0 / (step_time_delta + 1e-9 ),
493- self ._mode ,
494- step ,
495- )
496472
497473 if self ._mode == sft_metrics_logger .Mode .TRAIN :
498474 logging .info (
@@ -509,7 +485,6 @@ def _buffer_metrics(
509485 metrics_buffer : MetricsBuffer | None ,
510486 loss : ArrayLike ,
511487 step : int ,
512- step_time_delta : float = 0.0 ,
513488 additional_metrics : (
514489 dict [str , Tuple [ArrayLike , Callable [[ArrayLike ], ArrayLike ]]] | None
515490 ) = None ,
@@ -519,12 +494,10 @@ def _buffer_metrics(
519494 metrics_buffer = MetricsBuffer (
520495 step = step ,
521496 losses = [loss ],
522- step_time_deltas = [step_time_delta ],
523497 )
524498 else :
525499 assert metrics_buffer .step == step
526500 metrics_buffer .losses .append (loss )
527- metrics_buffer .step_time_deltas .append (step_time_delta or 0 )
528501 if additional_metrics is not None :
529502 for k , (v , op ) in additional_metrics .items ():
530503 if k not in metrics_buffer .additional_metrics :
@@ -548,7 +521,6 @@ def _write_train_metrics(self):
548521 self ._tqdm_train_metrics ,
549522 step = self ._prev_buffered_train_metrics .step ,
550523 loss = self ._prev_buffered_train_metrics .loss ,
551- step_time = self ._prev_buffered_train_metrics .step_time_delta ,
552524 )
553525 self ._prev_buffered_train_metrics = self ._buffered_train_metrics
554526 self ._buffered_train_metrics = None
@@ -564,7 +536,6 @@ def _to_np_array(v):
564536 self ._log_metrics (
565537 loss = metrics_buffer .loss ,
566538 step = metrics_buffer .step ,
567- step_time_delta = metrics_buffer .step_time_delta ,
568539 additional_metrics = {
569540 k : op (_to_np_array (v ))
570541 for k , (
@@ -592,15 +563,14 @@ def _may_update_pbar(
592563 metrics : list [str ],
593564 step : int | None = None ,
594565 loss : ArrayLike | None = None ,
595- step_time : float | None = None ,
596566 ):
597567 """Updates the progress bar with the given metrics if available."""
598568 if self ._pbar is not None :
599569 self ._pbar .update_metrics (metrics , self ._mode , ndigits = 3 )
600570 self ._pbar .update ()
601571
602572 if self .training_hooks and self ._mode == sft_metrics_logger .Mode .TRAIN :
603- self .training_hooks .on_train_step_end (self , step , loss , step_time )
573+ self .training_hooks .on_train_step_end (self , step , loss , 0.0 )
604574
605575 def train (
606576 self ,
@@ -725,16 +695,11 @@ def train(
725695 span .device_end ([train_loss ])
726696 span_v2 .async_end ([train_loss ])
727697
728- current_time = time .perf_counter ()
729- step_time_delta = current_time - last_step_completion_time
730- last_step_completion_time = current_time
731-
732698 self ._throttler .add_computation (train_loss )
733699 self ._buffered_train_metrics = self ._buffer_metrics (
734700 self ._buffered_train_metrics ,
735701 loss = train_loss ,
736702 step = self ._train_steps ,
737- step_time_delta = step_time_delta ,
738703 additional_metrics = {"grad_norm" : (grad_norm , np .mean )},
739704 )
740705 # NB: put this after self._buffer_metrics is important
0 commit comments