
Commit 655bb77

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f29ee07 commit 655bb77

File tree: 8 files changed (+67, -57 lines)

src/lightning/pytorch/callbacks/model_checkpoint.py (1 addition, 1 deletion)

@@ -470,7 +470,7 @@ def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
         # time-based validation: always defer saving to validation end
         if getattr(trainer, "_val_check_time_interval", None) is not None:
             return False
-
+
         # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
         # so let's not enforce saving at every training epoch end
         if trainer.check_val_every_n_epoch != 1:
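
In effect, the new guard makes epoch-end checkpointing defer to the time-based validation trigger. A minimal standalone sketch of that decision, assuming `_val_check_time_interval` holds the parsed interval in seconds (the fallback branch below is a simplification, not the callback's full logic):

def _should_save_on_train_epoch_end_sketch(trainer) -> bool:
    # Time-based validation configured: defer saving to validation end.
    if getattr(trainer, "_val_check_time_interval", None) is not None:
        return False
    # Simplified stand-in for the remaining epoch-based checks.
    return trainer.check_val_every_n_epoch == 1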

src/lightning/pytorch/loops/evaluation_loop.py (2 additions, 2 deletions)

@@ -15,11 +15,11 @@
 import os
 import shutil
 import sys
+import time
 from collections import ChainMap, OrderedDict, defaultdict
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
 from typing import Any, Optional, Union
-import time
 
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor

@@ -314,7 +314,7 @@ def on_run_end(self) -> list[_OUT_DICT]:
 
         if self.verbose and self.trainer.is_global_zero:
             self._print_results(logged_outputs, self._stage.value)
-
+
         now = time.monotonic()
         self.trainer._last_val_time = now
src/lightning/pytorch/loops/fit_loop.py (2 additions, 2 deletions)

@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import time
 from dataclasses import dataclass
 from typing import Any, Optional, Union
-import time
 
 import torch
 from typing_extensions import override

@@ -289,7 +289,7 @@ def setup_data(self) -> None:
         if getattr(trainer, "_val_check_time_interval", None) is not None:
            trainer.val_check_batch = None
             trainer._train_start_time = time.monotonic()
-            trainer._last_val_time = trainer._train_start_time
+            trainer._last_val_time = trainer._train_start_time
         elif isinstance(trainer.val_check_interval, int):
             trainer.val_check_batch = trainer.val_check_interval
             if trainer.val_check_batch > self.max_batches and trainer.check_val_every_n_epoch is not None:
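
Together with the `evaluation_loop.py` hunk above, which refreshes `trainer._last_val_time = time.monotonic()` once validation ends, this sets up the bookkeeping for a wall-clock trigger. A hedged sketch of the presumed check (the actual trigger lives elsewhere in the training loop and is not shown in this commit's diff):

import time

def _time_based_validation_due(trainer) -> bool:
    # Illustrative only: `_val_check_time_interval` is the parsed interval in
    # seconds; `_last_val_time` starts at `_train_start_time` and is refreshed
    # after each validation run.
    interval = getattr(trainer, "_val_check_time_interval", None)
    if interval is None:
        return False
    return time.monotonic() - trainer._last_val_time >= interval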

src/lightning/pytorch/loops/training_epoch_loop.py (1 addition, 1 deletion)

@@ -13,10 +13,10 @@
 # limitations under the License.
 import contextlib
 import math
+import time
 from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Any, Optional, Union
-import time
 
 import torch
 from typing_extensions import override
src/lightning/pytorch/trainer/connectors/data_connector.py (1 addition, 1 deletion)

@@ -14,8 +14,8 @@
 import os
 from collections.abc import Iterable
 from dataclasses import dataclass, field
-from typing import Any, Optional, Union
 from datetime import timedelta
+from typing import Any, Optional, Union
 
 import torch.multiprocessing as mp
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler

src/lightning/pytorch/trainer/setup.py (27 additions, 26 deletions)

@@ -13,8 +13,8 @@
 # limitations under the License.
 """Houses the methods used to set up the Trainer."""
 
-from typing import Optional, Union
 from datetime import timedelta
+from typing import Optional, Union
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning

@@ -199,29 +199,30 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
     if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
         rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
 
+
 def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float:
-    if isinstance(value, timedelta):
-        return value.total_seconds()
-    if isinstance(value, dict):
-        td = timedelta(**value)
-        return td.total_seconds()
-    if isinstance(value, str):
-        parts = value.split(":")
-        if len(parts) != 4:
-            raise MisconfigurationException(
-                f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'."
-            )
-        d, h, m, s = parts
-        try:
-            days = int(d)
-            hours = int(h)
-            minutes = int(m)
-            seconds = int(s)
-        except ValueError:
-            raise MisconfigurationException(
-                f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
-            )
-        td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
-        return td.total_seconds()
-    # Should not happen given the caller guards
-    raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")
+    if isinstance(value, timedelta):
+        return value.total_seconds()
+    if isinstance(value, dict):
+        td = timedelta(**value)
+        return td.total_seconds()
+    if isinstance(value, str):
+        parts = value.split(":")
+        if len(parts) != 4:
+            raise MisconfigurationException(
+                f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'."
+            )
+        d, h, m, s = parts
+        try:
+            days = int(d)
+            hours = int(h)
+            minutes = int(m)
+            seconds = int(s)
+        except ValueError:
+            raise MisconfigurationException(
+                f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
+            )
+        td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        return td.total_seconds()
+    # Should not happen given the caller guards
+    raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")
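
All three accepted duration spellings normalize to seconds. For example, applying the function shown above (return values follow directly from its body):

from datetime import timedelta

_parse_time_interval_seconds("00:01:30:00")          # 'DD:HH:MM:SS' -> 5400.0
_parse_time_interval_seconds(timedelta(minutes=90))  # timedelta    -> 5400.0
_parse_time_interval_seconds({"minutes": 90})        # dict kwargs  -> 5400.0

# "01:30:00" has only three ':'-separated fields, so it raises
# MisconfigurationException rather than being silently misread.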

src/lightning/pytorch/trainer/trainer.py (2 additions, 2 deletions)

@@ -109,7 +109,7 @@ def __init__(
         limit_test_batches: Optional[Union[int, float]] = None,
         limit_predict_batches: Optional[Union[int, float]] = None,
         overfit_batches: Union[int, float] = 0.0,
-        val_check_interval: Optional[Union[int, float, str, timedelta, dict[str,int]]] = None,
+        val_check_interval: Optional[Union[int, float, str, timedelta, dict[str, int]]] = None,
         check_val_every_n_epoch: Optional[int] = 1,
         num_sanity_val_steps: Optional[int] = None,
         log_every_n_steps: Optional[int] = None,

@@ -212,7 +212,7 @@ def __init__(
 
             check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
                 validation will be done solely based on the number of training batches, requiring ``val_check_interval``
-                to be an integer value. When used together with a time-based ``val_check_interval`` and
+                to be an integer value. When used together with a time-based ``val_check_interval`` and
                 ``check_val_every_n_epoch`` > 1, validation is aligned to epoch multiples: if the interval elapses
                 before the next multiple-N epoch, validation runs at the start of that epoch (after the first batch)
                 and the timer resets; if it elapses during a multiple-N epoch, validation runs after the current batch.
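
Given the widened signature and docstring, a time-based interval can be passed in any of the duration forms and combined with epoch gating. A minimal usage sketch (values are illustrative):

from datetime import timedelta
from lightning.pytorch import Trainer

# Run validation roughly every 30 minutes of wall-clock training time.
trainer = Trainer(val_check_interval="00:00:30:00")          # 'DD:HH:MM:SS' string
trainer = Trainer(val_check_interval=timedelta(minutes=30))  # datetime.timedelta
trainer = Trainer(val_check_interval={"minutes": 30})        # timedelta kwargs

# With check_val_every_n_epoch > 1, validation stays aligned to epoch
# multiples as described in the docstring above.
trainer = Trainer(val_check_interval="00:00:30:00", check_val_every_n_epoch=2)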

tests/tests_pytorch/trainer/flags/test_val_check_interval.py (31 additions, 22 deletions)

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-
-import pytest
-import time
 import re
+import time
 from unittest.mock import patch
+
+import pytest
 from torch.utils.data import DataLoader
 
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset

@@ -132,17 +132,19 @@ def test_val_check_interval_float_with_none_check_val_every_n_epoch():
     with pytest.raises(
         MisconfigurationException,
         match=re.escape(
-            "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
-            "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
-        )
+            "`val_check_interval` should be an integer or a time-based duration (str 'DD:HH:MM:SS', "
+            "datetime.timedelta, or dict kwargs for timedelta) when `check_val_every_n_epoch=None`."
+        ),
     ):
         Trainer(
             val_check_interval=0.5,
             check_val_every_n_epoch=None,
         )
 
+
 def test_time_based_val_check_interval(tmp_path):
     call_count = {"count": 0}
+
     def fake_time():
         result = call_count["count"]
         call_count["count"] += 2

@@ -168,17 +170,20 @@
 
 
 @pytest.mark.parametrize(
-    "check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description",
+    ("check_val_every_n_epoch", "val_check_interval", "epoch_duration", "expected_val_batches", "description"),
     [
         (None, "00:00:00:04", 2, [0, 1, 0, 1, 0], "val_check_interval timer only, no epoch gating"),
         (1, "00:00:00:06", 8, [1, 1, 2, 1, 1], "val_check_interval timer only, no epoch gating"),
         (2, "00:00:00:06", 9, [0, 2, 0, 2, 0], "epoch gating, timer longer than epoch"),
         (2, "00:00:00:20", 9, [0, 0, 0, 1, 0], "epoch gating, timer much longer"),
         (2, "00:00:00:03", 9, [0, 3, 0, 3, 0], "epoch gating, timer shorter than epoch"),
-    ]
+    ],
 )
-def test_time_and_epoch_gated_val_check(tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description):
+def test_time_and_epoch_gated_val_check(
+    tmp_path, check_val_every_n_epoch, val_check_interval, epoch_duration, expected_val_batches, description
+):
     call_count = {"count": 0}
+
     # Simulate time in steps (each batch is 1 second, epoch_duration=seconds per epoch)
     def fake_time():
         result = call_count["count"]

@@ -191,7 +196,11 @@ class TestModel(BoringModel):
         val_epoch_calls = 0
 
         def on_train_batch_end(self, *args, **kwargs):
-            if isinstance(self.trainer.check_val_every_n_epoch,int) and self.trainer.check_val_every_n_epoch > 1 and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
+            if (
+                isinstance(self.trainer.check_val_every_n_epoch, int)
+                and self.trainer.check_val_every_n_epoch > 1
+                and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0
+            ):
                 time.monotonic()
 
         def on_train_epoch_end(self, *args, **kwargs):

@@ -205,17 +214,17 @@ def on_validation_epoch_start(self) -> None:
     max_steps = max_epochs * epoch_duration
     limit_train_batches = epoch_duration
 
-    trainer_kwargs = dict(
-        default_root_dir=tmp_path,
-        logger=False,
-        enable_checkpointing=False,
-        max_epochs=max_epochs,
-        max_steps=max_steps,
-        limit_val_batches=1,
-        limit_train_batches=limit_train_batches,
-        val_check_interval=val_check_interval,
-        check_val_every_n_epoch=check_val_every_n_epoch
-    )
+    trainer_kwargs = {
+        "default_root_dir": tmp_path,
+        "logger": False,
+        "enable_checkpointing": False,
+        "max_epochs": max_epochs,
+        "max_steps": max_steps,
+        "limit_val_batches": 1,
+        "limit_train_batches": limit_train_batches,
+        "val_check_interval": val_check_interval,
+        "check_val_every_n_epoch": check_val_every_n_epoch,
+    }
 
     with patch("time.monotonic", side_effect=fake_time):
         model = TestModel()

@@ -227,4 +236,4 @@ def on_validation_epoch_start(self) -> None:
         f"\nFAILED: {description}"
         f"\nExpected validation at batches: {expected_val_batches},"
         f"\nGot: {model.val_batches, model.val_epoch_calls}\n"
-    )
+    )
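
The tests stub the clock instead of sleeping: each call to the patched `time.monotonic` advances a counter by two simulated seconds. The same pattern in isolation:

import time
from unittest.mock import patch

call_count = {"count": 0}

def fake_time():
    # Returns 0, 2, 4, ... so elapsed time in the test is fully deterministic.
    result = call_count["count"]
    call_count["count"] += 2
    return result

with patch("time.monotonic", side_effect=fake_time):
    assert time.monotonic() == 0
    assert time.monotonic() == 2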
