@@ -14,6 +14,7 @@
 """Houses the methods used to set up the Trainer."""
 
 from typing import Optional, Union
+from datetime import timedelta
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning
@@ -40,7 +41,7 @@ def _init_debugging_flags(
     limit_predict_batches: Optional[Union[int, float]],
     fast_dev_run: Union[int, bool],
     overfit_batches: Union[int, float],
-    val_check_interval: Optional[Union[int, float]],
+    val_check_interval: Optional[Union[int, float, str, timedelta, dict]],
     num_sanity_val_steps: int,
 ) -> None:
     # init debugging flags
@@ -69,6 +70,7 @@ def _init_debugging_flags(
         trainer.num_sanity_val_steps = 0
         trainer.fit_loop.max_epochs = 1
         trainer.val_check_interval = 1.0
+        trainer._val_check_time_interval = None  # time-based checks are not applicable in fast_dev_run
         trainer.check_val_every_n_epoch = 1
         trainer.loggers = [DummyLogger()] if trainer.loggers else []
         rank_zero_info(
@@ -82,7 +84,16 @@ def _init_debugging_flags(
         trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
         trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
         trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
-        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
+        # Support time-based validation intervals:
+        # if `val_check_interval` is a str/dict/timedelta, parse it and store the
+        # interval in seconds on the trainer for the loops to consume.
+        trainer._val_check_time_interval = None  # default
+        if isinstance(val_check_interval, (str, dict, timedelta)):
+            trainer._val_check_time_interval = _parse_time_interval_seconds(val_check_interval)
+            # Keep the numeric scheduler neutral; the loops check the time-based attribute instead.
+            trainer.val_check_interval = 1.0
+        else:
+            trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")
 
     if overfit_batches_enabled:
         trainer.limit_train_batches = overfit_batches
@@ -187,3 +198,30 @@ def _log_device_info(trainer: "pl.Trainer") -> None:
 
     if HPUAccelerator.is_available() and not isinstance(trainer.accelerator, HPUAccelerator):
         rank_zero_warn("HPU available but not used. You can set it by doing `Trainer(accelerator='hpu')`.")
+
+def _parse_time_interval_seconds(value: Union[str, timedelta, dict]) -> float:
+    if isinstance(value, timedelta):
+        return value.total_seconds()
+    if isinstance(value, dict):
+        td = timedelta(**value)
+        return td.total_seconds()
+    if isinstance(value, str):
+        parts = value.split(":")
+        if len(parts) != 4:
+            raise MisconfigurationException(
+                f"Invalid time format for `val_check_interval`: {value!r}. Expected 'DD:HH:MM:SS'."
+            )
+        d, h, m, s = parts
+        try:
+            days = int(d)
+            hours = int(h)
+            minutes = int(m)
+            seconds = int(s)
+        except ValueError:
+            raise MisconfigurationException(
+                f"Non-integer component in `val_check_interval` string: {value!r}. Use 'DD:HH:MM:SS'."
+            )
+        td = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        return td.total_seconds()
+    # Should not happen given the caller guards
+    raise MisconfigurationException(f"Unsupported type for `val_check_interval`: {type(value)!r}")
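
For reference, a minimal usage sketch (illustrative only, assuming this diff is applied): the three new spellings of `val_check_interval` below all resolve to the same 30-minute interval via `_parse_time_interval_seconds`, which the trainer stores in seconds on `_val_check_time_interval`.

    from datetime import timedelta
    import lightning.pytorch as pl

    # Equivalent time-based intervals (30 minutes = 1800 seconds):
    trainer = pl.Trainer(val_check_interval=timedelta(minutes=30))
    trainer = pl.Trainer(val_check_interval={"minutes": 30})
    trainer = pl.Trainer(val_check_interval="00:00:30:00")  # 'DD:HH:MM:SS'

    # Per this diff, the parsed value lands on the trainer in seconds,
    # while the numeric scheduler stays neutral at 1.0:
    assert trainer._val_check_time_interval == 1800.0
    assert trainer.val_check_interval == 1.0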