Skip to content

Commit a00a999

Browse files
awaelchli authored and lantiga committed
Fix log_every_n_steps check in ThroughputMonitor (#19470)
1 parent eb90aa4 commit a00a999

File tree

3 files changed

+22
-37
lines changed

3 files changed

+22
-37
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010
### Fixed
1111

1212
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
13+
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))
1314

1415

1516
## [2.2.0] - 2024-02-08

src/lightning/pytorch/callbacks/throughput_monitor.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,21 +93,10 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
9393
dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
9494
self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)
9595

96-
if stage == TrainerFn.FITTING:
97-
if trainer.accumulate_grad_batches % trainer.log_every_n_steps != 0:
98-
raise ValueError(
99-
"The `ThroughputMonitor` only logs when gradient accumulation is finished. You set"
100-
f" `Trainer(accumulate_grad_batches={trainer.accumulate_grad_batches},"
101-
f" log_every_n_steps={trainer.log_every_n_steps})` but these are not divisible and thus will not"
102-
" log anything."
103-
)
104-
105-
if trainer.enable_validation:
106-
# `fit` includes validation inside
107-
throughput = Throughput(
108-
available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
109-
)
110-
self._throughputs[RunningStage.VALIDATING] = throughput
96+
if stage == TrainerFn.FITTING and trainer.enable_validation:
97+
# `fit` includes validation inside
98+
throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
99+
self._throughputs[RunningStage.VALIDATING] = throughput
111100

112101
throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
113102
stage = trainer.state.stage

tests/tests_pytorch/callbacks/test_throughput_monitor.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path):
160160
]
161161

162162

163-
def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
163+
@pytest.mark.parametrize("log_every_n_steps", [1, 3])
164+
def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
164165
logger_mock = Mock()
165166
logger_mock.save_dir = tmp_path
166167
monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|")
@@ -174,26 +175,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
174175
limit_train_batches=5,
175176
limit_val_batches=0,
176177
max_epochs=2,
177-
log_every_n_steps=3,
178+
log_every_n_steps=log_every_n_steps,
178179
accumulate_grad_batches=2,
179-
num_sanity_val_steps=2,
180-
enable_checkpointing=False,
181-
enable_model_summary=False,
182-
enable_progress_bar=False,
183-
)
184-
with pytest.raises(ValueError, match="not divisible"):
185-
trainer.fit(model)
186-
187-
trainer = Trainer(
188-
devices=1,
189-
logger=logger_mock,
190-
callbacks=monitor,
191-
limit_train_batches=5,
192-
limit_val_batches=0,
193-
max_epochs=2,
194-
log_every_n_steps=1,
195-
accumulate_grad_batches=2,
196-
num_sanity_val_steps=2,
197180
enable_checkpointing=False,
198181
enable_model_summary=False,
199182
enable_progress_bar=False,
@@ -211,9 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
211194
"train|device|flops_per_sec": 10.0,
212195
"train|device|mfu": 0.1,
213196
}
214-
assert logger_mock.log_metrics.mock_calls == [
197+
198+
all_log_calls = [
215199
call(
216-
metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0
200+
metrics={
201+
# The very first batch doesn't have the *_per_sec metrics yet
202+
**(expected if log_every_n_steps > 1 else {}),
203+
"train|time": 2.5,
204+
"train|batches": 2,
205+
"train|samples": 6,
206+
"train|lengths": 12,
207+
"epoch": 0,
208+
},
209+
step=0,
217210
),
218211
call(
219212
metrics={
@@ -271,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
271264
step=5,
272265
),
273266
]
267+
expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]
268+
assert logger_mock.log_metrics.mock_calls == expected_log_calls
274269

275270

276271
@pytest.mark.parametrize("fn", ["validate", "test", "predict"])

0 commit comments

Comments (0)