@@ -120,7 +120,8 @@ class Trainer(Recipe, Generic[BatchT]):
120120 _seed : int
121121 _max_num_steps : int | None
122122 _max_num_data_epochs : int | None
123- _validator : Validator
123+ _validator : Validator | None
124+ _validate_at_start : bool
124125 _validate_after_n_steps : int
125126 _validate_every_n_steps : int | None
126127 _validate_after_n_data_epochs : int
@@ -181,6 +182,7 @@ def __init__(
181182 anomaly_detection : bool = False ,
182183 max_num_steps : int | None = None ,
183184 max_num_data_epochs : int | None = None ,
185+ validate_at_start : bool = False ,
184186 validate_after_n_steps : int = 0 ,
185187 validate_every_n_steps : int | None = None ,
186188 validate_after_n_data_epochs : int = 0 ,
@@ -328,6 +330,8 @@ def __init__(
328330
329331 self ._validator = validator
330332
333+ self ._validate_at_start = validate_at_start
334+
331335 if validate_every_n_steps is not None :
332336 if validate_every_n_steps <= 0 :
333337 raise ValueError (
@@ -507,6 +511,9 @@ def _maybe_restore_state(self) -> _TrainerState:
507511 ) from ex
508512
509513 if step_nr is None :
514+ if self ._validate_at_start :
515+ return _TrainerState .PRE_VALIDATION
516+
510517 return _TrainerState .DATA_LOAD
511518
512519 log .info ("Restoring training from the last checkpoint at step {}." , step_nr )
@@ -574,6 +581,9 @@ def _do_run(self) -> None:
574581 with progress_task , self ._lapse_watch :
575582 while self ._state != _TrainerState .STOPPED :
576583 match self ._state :
584+ case _TrainerState .PRE_VALIDATION :
585+ self ._state = self ._pre_validate ()
586+
577587 case _TrainerState .DATA_LOAD :
578588 self ._state = self ._read_next_batches ()
579589
@@ -607,6 +617,12 @@ def _do_run(self) -> None:
607617
608618 self ._state = self ._stop ()
609619
620+ def _pre_validate (self ) -> _TrainerState :
621+ if self ._validate is not None :
622+ self ._validate ()
623+
624+ return _TrainerState .DATA_LOAD
625+
610626 def _read_next_batches (self ) -> _TrainerState :
611627 with self ._data_watch :
612628 try :
@@ -1045,6 +1061,9 @@ def _maybe_validate(self) -> float | None:
10451061 return score
10461062
10471063 def _should_validate (self ) -> bool :
1064+ if self ._validator is None :
1065+ return False
1066+
10481067 return self ._should_do (
10491068 self ._validate_after_n_steps ,
10501069 self ._validate_every_n_steps ,
@@ -1053,7 +1072,13 @@ def _should_validate(self) -> bool:
10531072 )
10541073
10551074 def _validate (self ) -> float | None :
1056- log .info ("Starting validation after step {}." , self ._step_nr )
1075+ if self ._validator is None :
1076+ raise InternalError ("`_validator` is `None`." )
1077+
1078+ if self ._step_nr == 0 :
1079+ log .info ("Starting pre-validation before training." )
1080+ else :
1081+ log .info ("Starting validation after step {}." , self ._step_nr )
10571082
10581083 self ._model .module .eval ()
10591084
@@ -1135,9 +1160,12 @@ def _should_do(
11351160 after_n_data_epochs : int ,
11361161 every_n_data_epochs : int | None ,
11371162 ) -> bool :
1163+ if self ._state == _TrainerState .PRE_VALIDATION :
1164+ return False
1165+
11381166 def should_do_at_step () -> bool :
11391167 if every_n_steps is not None :
1140- if self ._step_nr >= after_n_steps :
1168+ if self ._step_nr > after_n_steps :
11411169 if self ._step_nr % every_n_steps == 0 :
11421170 return True
11431171
@@ -1158,7 +1186,7 @@ def should_do_at_step() -> bool:
11581186
11591187 if self ._state == _TrainerState .END_OF_DATA_EPOCH :
11601188 if every_n_data_epochs is not None :
1161- if self ._data_epoch_nr >= after_n_data_epochs :
1189+ if self ._data_epoch_nr > after_n_data_epochs :
11621190 if self ._data_epoch_nr % every_n_data_epochs == 0 :
11631191 already_done = should_do_at_step ()
11641192
@@ -1191,16 +1219,17 @@ def step_nr(self) -> int:
11911219
11921220class _TrainerState (Enum ):
11931221 NOT_STARTED = 0
1194- DATA_LOAD = 1
1195- STEP = 2
1196- POST_STEP = 3
1197- END_OF_DATA_EPOCH = 4
1198- END_OF_TRAINING = 5
1199- END_OF_DATA = 6
1200- GRADIENT_OVERFLOW = 7
1201- EARLY_STOP = 8
1202- STOP_REQUESTED = 9
1203- STOPPED = 10
1222+ PRE_VALIDATION = 1
1223+ DATA_LOAD = 2
1224+ STEP = 3
1225+ POST_STEP = 4
1226+ END_OF_DATA_EPOCH = 5
1227+ END_OF_TRAINING = 6
1228+ END_OF_DATA = 7
1229+ GRADIENT_OVERFLOW = 8
1230+ EARLY_STOP = 9
1231+ STOP_REQUESTED = 10
1232+ STOPPED = 11
12041233
12051234
12061235T = TypeVar ("T" )
0 commit comments