
Commit 88b26be

Support variable batch size in throughput callback
1 parent f3f10d4 commit 88b26be

File tree

1 file changed: +5 −2 lines changed


src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 5 additions & 2 deletions
@@ -87,6 +87,7 @@ def __init__(
         self._throughputs: Dict[RunningStage, Throughput] = {}
         self._t0s: Dict[RunningStage, float] = {}
         self._lengths: Dict[RunningStage, int] = {}
+        self._samples: Dict[RunningStage, int] = {}
 
     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -109,6 +110,7 @@ def _start(self, trainer: "Trainer") -> None:
         self._throughputs[stage].reset()
         self._lengths[stage] = 0
         self._t0s[stage] = time.perf_counter()
+        self._samples[stage] = 0
 
     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -133,12 +135,13 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             )
             flops_per_batch = None
 
-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+
         throughput.update(
             time=elapsed,
             batches=iter_num,
             # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,
         )
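
For context, a minimal usage sketch of what this change enables. ThroughputMonitor, Trainer, and the batch_size_fn parameter are the real Lightning APIs touched by this commit; the dict-shaped batch and the specific batch_size_fn body below are hypothetical. The point of the change is that the value returned by batch_size_fn is now accumulated per stage instead of being assumed constant across iterations.

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ThroughputMonitor

    # Hypothetical batch layout: each batch is a dict whose "input_ids" tensor has a
    # variable first dimension (e.g. a sampler that groups sequences by length).
    def batch_size_fn(batch) -> int:
        return batch["input_ids"].shape[0]

    trainer = Trainer(callbacks=[ThroughputMonitor(batch_size_fn=batch_size_fn)])

    # Before this commit the callback reported samples as iter_num * batch_size_fn(batch),
    # extrapolating the latest batch size to every iteration. After it, the callback keeps
    # a per-stage running total (self._samples[stage] += self.batch_size_fn(batch)), so the
    # samples/sec figure stays correct when batch sizes vary between iterations.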
