Skip to content

Commit 8ea9ef7

Browse files
tianshub and The tunix Authors
authored and committed
remove step time metrics
PiperOrigin-RevId: 889401107
1 parent 560e7d0 commit 8ea9ef7

File tree

1 file changed

+1
-36
lines changed

1 file changed

+1
-36
lines changed

tunix/sft/peft_trainer.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,13 @@ class MetricsBuffer:
113113
step: The training step number.
114114
losses: A list of loss values recorded within this step (e.g., across
115115
gradient accumulation steps).
116-
step_time_deltas: A list of time deltas for each computation within this
117-
step.
118116
additional_metrics: Dictionary for storing additional metrics. The key is
119117
the metric name, and the value is a tuple containing a list of metric
120118
values and a callable to aggregate them.
121119
"""
122120

123121
step: int
124122
losses: List[ArrayLike]
125-
step_time_deltas: List[float]
126123
additional_metrics: Dict[
127124
str, Tuple[List[ArrayLike], Callable[[ArrayLike], ArrayLike]]
128125
] = dataclasses.field(default_factory=dict)
@@ -132,11 +129,6 @@ def loss(self):
132129
"""Returns the mean of the recorded losses for the step."""
133130
return np.mean(np.array([np.array(x) for x in self.losses]))
134131

135-
@property
136-
def step_time_delta(self):
137-
"""Returns the mean of the recorded step time deltas for the step."""
138-
return np.mean(self.step_time_deltas)
139-
140132

141133
def _calculate_global_batch_size(train_example: Any) -> int:
142134
"""Calculates the global batch size from a training example.
@@ -460,7 +452,6 @@ def _log_metrics(
460452
self,
461453
loss: ArrayLike,
462454
step: int | None = None,
463-
step_time_delta: float | None = None,
464455
additional_metrics: dict[str, ArrayLike] | None = None,
465456
):
466457
"""Logs the metrics to the metrics logger and console."""
@@ -478,21 +469,6 @@ def _log_metrics(
478469
self._mode,
479470
step,
480471
)
481-
if step_time_delta is not None:
482-
self.metrics_logger.log(
483-
self.metrics_prefix,
484-
"step_time_sec",
485-
step_time_delta,
486-
self._mode,
487-
step,
488-
)
489-
self.metrics_logger.log(
490-
self.metrics_prefix,
491-
"steps_per_sec",
492-
1.0 / (step_time_delta + 1e-9),
493-
self._mode,
494-
step,
495-
)
496472

497473
if self._mode == sft_metrics_logger.Mode.TRAIN:
498474
logging.info(
@@ -509,7 +485,6 @@ def _buffer_metrics(
509485
metrics_buffer: MetricsBuffer | None,
510486
loss: ArrayLike,
511487
step: int,
512-
step_time_delta: float = 0.0,
513488
additional_metrics: (
514489
dict[str, Tuple[ArrayLike, Callable[[ArrayLike], ArrayLike]]] | None
515490
) = None,
@@ -519,12 +494,10 @@ def _buffer_metrics(
519494
metrics_buffer = MetricsBuffer(
520495
step=step,
521496
losses=[loss],
522-
step_time_deltas=[step_time_delta],
523497
)
524498
else:
525499
assert metrics_buffer.step == step
526500
metrics_buffer.losses.append(loss)
527-
metrics_buffer.step_time_deltas.append(step_time_delta or 0)
528501
if additional_metrics is not None:
529502
for k, (v, op) in additional_metrics.items():
530503
if k not in metrics_buffer.additional_metrics:
@@ -548,7 +521,6 @@ def _write_train_metrics(self):
548521
self._tqdm_train_metrics,
549522
step=self._prev_buffered_train_metrics.step,
550523
loss=self._prev_buffered_train_metrics.loss,
551-
step_time=self._prev_buffered_train_metrics.step_time_delta,
552524
)
553525
self._prev_buffered_train_metrics = self._buffered_train_metrics
554526
self._buffered_train_metrics = None
@@ -564,7 +536,6 @@ def _to_np_array(v):
564536
self._log_metrics(
565537
loss=metrics_buffer.loss,
566538
step=metrics_buffer.step,
567-
step_time_delta=metrics_buffer.step_time_delta,
568539
additional_metrics={
569540
k: op(_to_np_array(v))
570541
for k, (
@@ -592,15 +563,14 @@ def _may_update_pbar(
592563
metrics: list[str],
593564
step: int | None = None,
594565
loss: ArrayLike | None = None,
595-
step_time: float | None = None,
596566
):
597567
"""Updates the progress bar with the given metrics if available."""
598568
if self._pbar is not None:
599569
self._pbar.update_metrics(metrics, self._mode, ndigits=3)
600570
self._pbar.update()
601571

602572
if self.training_hooks and self._mode == sft_metrics_logger.Mode.TRAIN:
603-
self.training_hooks.on_train_step_end(self, step, loss, step_time)
573+
self.training_hooks.on_train_step_end(self, step, loss, 0.0)
604574

605575
def train(
606576
self,
@@ -725,16 +695,11 @@ def train(
725695
span.device_end([train_loss])
726696
span_v2.async_end([train_loss])
727697

728-
current_time = time.perf_counter()
729-
step_time_delta = current_time - last_step_completion_time
730-
last_step_completion_time = current_time
731-
732698
self._throttler.add_computation(train_loss)
733699
self._buffered_train_metrics = self._buffer_metrics(
734700
self._buffered_train_metrics,
735701
loss=train_loss,
736702
step=self._train_steps,
737-
step_time_delta=step_time_delta,
738703
additional_metrics={"grad_norm": (grad_norm, np.mean)},
739704
)
740705
# NB: put this after self._buffer_metrics is important

0 commit comments

Comments (0)