Skip to content

Commit d206ca3

Browse files
yMayan, thinkin-machine, Borda, carmocca
authored
added support for logging in different trainer stages (#16002)
Co-authored-by: thinkin-machine <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent ea1899e commit d206ca3

File tree

3 files changed

+94
-2
lines changed

3 files changed

+94
-2
lines changed

src/lightning_app/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616

1717
- The LoadBalancer now uses internal ip + port instead of URL exposed ([#16119](https://github.com/Lightning-AI/lightning/pull/16119))
1818

19+
- Added support for logging in different trainer stages with `DeviceStatsMonitor`
20+
([#16002](https://github.com/Lightning-AI/lightning/pull/16002))
21+
1922

2023
### Deprecated
2124

src/pytorch_lightning/callbacks/device_stats_monitor.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030

3131
class DeviceStatsMonitor(Callback):
3232
r"""
33-
Automatically monitors and logs device stats during training stage. ``DeviceStatsMonitor``
34-
is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``.
33+
Automatically monitors and logs device stats during training, validation and testing stage.
34+
``DeviceStatsMonitor`` is a special callback as it requires a ``logger`` to passed as argument
35+
to the ``Trainer``.
3536
3637
Args:
3738
cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
@@ -109,6 +110,38 @@ def on_train_batch_end(
109110
) -> None:
110111
self._get_and_log_device_stats(trainer, "on_train_batch_end")
111112

113+
def on_validation_batch_start(
114+
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
115+
) -> None:
116+
self._get_and_log_device_stats(trainer, "on_validation_batch_start")
117+
118+
def on_validation_batch_end(
119+
self,
120+
trainer: "pl.Trainer",
121+
pl_module: "pl.LightningModule",
122+
outputs: Optional[STEP_OUTPUT],
123+
batch: Any,
124+
batch_idx: int,
125+
dataloader_idx: int,
126+
) -> None:
127+
self._get_and_log_device_stats(trainer, "on_validation_batch_end")
128+
129+
def on_test_batch_start(
130+
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
131+
) -> None:
132+
self._get_and_log_device_stats(trainer, "on_test_batch_start")
133+
134+
def on_test_batch_end(
135+
self,
136+
trainer: "pl.Trainer",
137+
pl_module: "pl.LightningModule",
138+
outputs: Optional[STEP_OUTPUT],
139+
batch: Any,
140+
batch_idx: int,
141+
dataloader_idx: int,
142+
) -> None:
143+
self._get_and_log_device_stats(trainer, "on_test_batch_end")
144+
112145

113146
def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
114147
return {prefix + separator + k: v for k, v in metrics_dict.items()}

tests/tests_pytorch/callbacks/test_device_stats_monitor.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import csv
1415
import os
16+
import re
1517
from typing import Dict, Optional
1618
from unittest import mock
1719
from unittest.mock import Mock
@@ -166,3 +168,57 @@ def test_device_stats_monitor_warning_when_psutil_not_available(monkeypatch, tmp
166168
# TODO: raise an exception from v1.9
167169
with pytest.warns(UserWarning, match="psutil` is not installed"):
168170
monitor.setup(trainer, Mock(), "fit")
171+
172+
173+
def test_device_stats_monitor_logs_for_different_stages(tmpdir):
174+
"""Test that metrics are logged for all stages that is training, testing and validation."""
175+
176+
model = BoringModel()
177+
device_stats = DeviceStatsMonitor()
178+
179+
trainer = Trainer(
180+
default_root_dir=tmpdir,
181+
max_epochs=1,
182+
limit_train_batches=4,
183+
limit_val_batches=4,
184+
limit_test_batches=1,
185+
log_every_n_steps=1,
186+
accelerator="cpu",
187+
devices=1,
188+
callbacks=[device_stats],
189+
logger=CSVLogger(tmpdir),
190+
enable_checkpointing=False,
191+
enable_progress_bar=False,
192+
)
193+
194+
# training and validation stages will run
195+
trainer.fit(model)
196+
197+
with open(f"{tmpdir}/lightning_logs/version_0/metrics.csv") as csvfile:
198+
199+
content = csv.reader(csvfile, delimiter=",")
200+
it = iter(content).__next__()
201+
202+
# searching for training stage logs
203+
train_stage_results = [re.match(r".+on_train_batch", i) for i in it]
204+
train = any(train_stage_results)
205+
assert train, "training stage logs not found"
206+
207+
# searching for validation stage logs
208+
validation_stage_results = [re.match(r".+on_validation_batch", i) for i in it]
209+
valid = any(validation_stage_results)
210+
assert valid, "validation stage logs not found"
211+
212+
# testing stage will run
213+
trainer.test(model)
214+
215+
with open(f"{tmpdir}/lightning_logs/version_0/metrics.csv") as csvfile:
216+
217+
content = csv.reader(csvfile, delimiter=",")
218+
it = iter(content).__next__()
219+
220+
# searching for testing stage logs
221+
test_stage_results = [re.match(r".+on_test_batch", i) for i in it]
222+
test = any(test_stage_results)
223+
224+
assert test, "testing stage logs not found"

0 commit comments

Comments (0)