|
16 | 16 | import collections
|
17 | 17 | import itertools
|
18 | 18 | from re import escape
|
| 19 | +from unittest import mock |
| 20 | +from unittest.mock import call |
19 | 21 |
|
20 | 22 | import numpy as np
|
21 | 23 | import pytest
|
@@ -747,3 +749,37 @@ def validation_epoch_end(self, *_) -> None:
|
747 | 749 | train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
|
748 | 750 | val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
|
749 | 751 | trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
|
| 752 | + |
| 753 | + |
| 754 | +@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") |
| 755 | +def test_log_metrics_epoch_step_values(mock_log_metrics, tmpdir): |
| 756 | + """Tests the default epoch and step values logged.""" |
| 757 | + |
| 758 | + class MyModel(BoringModel): |
| 759 | + def training_step(self, batch, batch_idx): |
| 760 | + self.log("foo", 0.0, on_step=True, on_epoch=True) |
| 761 | + return super().training_step(batch, batch_idx) |
| 762 | + |
| 763 | + model = MyModel() |
| 764 | + trainer = Trainer( |
| 765 | + default_root_dir=tmpdir, |
| 766 | + limit_train_batches=2, |
| 767 | + limit_val_batches=0, |
| 768 | + max_epochs=2, |
| 769 | + log_every_n_steps=1, |
| 770 | + enable_model_summary=False, |
| 771 | + enable_checkpointing=False, |
| 772 | + enable_progress_bar=False, |
| 773 | + ) |
| 774 | + trainer.fit(model) |
| 775 | + |
| 776 | + mock_log_metrics.assert_has_calls( |
| 777 | + [ |
| 778 | + call(metrics={"foo_step": 0.0, "epoch": 0}, step=0), |
| 779 | + call(metrics={"foo_step": 0.0, "epoch": 0}, step=1), |
| 780 | + call(metrics={"foo_epoch": 0.0, "epoch": 0}, step=1), |
| 781 | + call(metrics={"foo_step": 0.0, "epoch": 1}, step=2), |
| 782 | + call(metrics={"foo_step": 0.0, "epoch": 1}, step=3), |
| 783 | + call(metrics={"foo_epoch": 0.0, "epoch": 1}, step=3), |
| 784 | + ] |
| 785 | + ) |
0 commit comments