
Commit f4ee1c7

fix implementation

1 parent e0d4700 commit f4ee1c7
File tree

1 file changed (+10, -4 lines)

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 10 additions & 4 deletions
@@ -88,6 +88,7 @@ def __init__(
         self._t0s: dict[RunningStage, float] = {}
         self._lengths: dict[RunningStage, int] = {}
         self._samples: dict[RunningStage, int] = {}
+        self._batches: dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -107,10 +108,14 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
     def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
-        self._throughputs[stage].reset()
-        self._lengths[stage] = 0
+
+        if stage not in self._samples:
+            self._throughputs[stage].reset()
+            self._lengths[stage] = 0
+            self._samples[stage] = 0
+            self._batches[stage] = 0
+
         self._t0s[stage] = time.perf_counter()
-        self._samples[stage] = 0
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -136,10 +141,11 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             flops_per_batch = None
 
         self._samples[stage] += self.batch_size_fn(batch)
+        self._batches[stage] += 1
 
         throughput.update(
             time=elapsed,
-            batches=iter_num,
+            batches=self._batches[stage],
             # this assumes that all iterations used the same batch size
             samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
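
In effect, the commit makes two changes: _start now resets the per-stage throughput state only the first time a stage is encountered, so the running totals survive repeated _start calls for the same stage, and _update reports a cumulative per-stage batch counter (self._batches[stage]) to throughput.update instead of the loop-local iter_num, which restarts whenever the stage's loop restarts. Below is a minimal standalone sketch of the same pattern; the Monitor class is hypothetical and is not the Lightning implementation.

import time
from typing import Any, Callable

# Hypothetical stand-in for the callback, illustrating only the fixed pattern.
class Monitor:
    def __init__(self, batch_size_fn: Callable[[Any], int]) -> None:
        self.batch_size_fn = batch_size_fn
        self._t0s: dict[str, float] = {}
        self._samples: dict[str, int] = {}
        self._batches: dict[str, int] = {}

    def start(self, stage: str) -> None:
        # Reset counters only the first time this stage starts; later calls
        # (e.g. validation re-entered after every training epoch) keep the
        # running totals, mirroring the `if stage not in self._samples` guard.
        if stage not in self._samples:
            self._samples[stage] = 0
            self._batches[stage] = 0
        self._t0s[stage] = time.perf_counter()

    def update(self, stage: str, batch: Any) -> dict[str, Any]:
        elapsed = time.perf_counter() - self._t0s[stage]
        self._samples[stage] += self.batch_size_fn(batch)
        # Count batches explicitly: a loop-local iter_num restarts from zero
        # whenever the stage's loop restarts and would undercount batches.
        self._batches[stage] += 1
        return {"time": elapsed, "batches": self._batches[stage], "samples": self._samples[stage]}

monitor = Monitor(batch_size_fn=len)
monitor.start("validate")
monitor.update("validate", [0] * 32)         # batches=1, samples=32
monitor.start("validate")                    # second round: totals persist
print(monitor.update("validate", [0] * 32))  # batches=2, samples=64

With this pattern, re-entering a stage continues the batch count where it left off rather than restarting at 1, which keeps the batches, samples, and time values passed to the throughput tracker consistent with one another.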
