Skip to content

Commit 67c8939

Browse files
committed
Introduce validate_at_start
1 parent 5136e5a commit 67c8939

File tree

6 files changed

+52
-19
lines changed

6 files changed

+52
-19
lines changed

src/fairseq2/nn/utils/mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def apply_mask(
2828
2929
:returns: The input sequences with mask applied. *Shape:* Same as ``seqs``.
3030
"""
31-
unsqueeze(mask, dim=-1, count=seqs.ndim - mask.ndim)
31+
mask = unsqueeze(mask, dim=-1, count=seqs.ndim - mask.ndim)
3232

3333
return seqs.where(mask, fill_value)
3434

@@ -95,7 +95,7 @@ def compute_row_mask(
9595
)
9696
else:
9797
# (N)
98-
row_lens = row_lens.view(num_rows)
98+
row_lens = row_lens.to(torch.int64).view(num_rows)
9999

100100
# We only mask rows that are longer than the mask span length.
101101
if (span_len >= row_lens).any():

src/fairseq2/recipes/_trainer.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ class Trainer(Recipe, Generic[BatchT]):
120120
_seed: int
121121
_max_num_steps: int | None
122122
_max_num_data_epochs: int | None
123-
_validator: Validator
123+
_validator: Validator | None
124+
_validate_at_start: bool
124125
_validate_after_n_steps: int
125126
_validate_every_n_steps: int | None
126127
_validate_after_n_data_epochs: int
@@ -181,6 +182,7 @@ def __init__(
181182
anomaly_detection: bool = False,
182183
max_num_steps: int | None = None,
183184
max_num_data_epochs: int | None = None,
185+
validate_at_start: bool = False,
184186
validate_after_n_steps: int = 0,
185187
validate_every_n_steps: int | None = None,
186188
validate_after_n_data_epochs: int = 0,
@@ -328,6 +330,8 @@ def __init__(
328330

329331
self._validator = validator
330332

333+
self._validate_at_start = validate_at_start
334+
331335
if validate_every_n_steps is not None:
332336
if validate_every_n_steps <= 0:
333337
raise ValueError(
@@ -507,6 +511,9 @@ def _maybe_restore_state(self) -> _TrainerState:
507511
) from ex
508512

509513
if step_nr is None:
514+
if self._validate_at_start:
515+
return _TrainerState.PRE_VALIDATION
516+
510517
return _TrainerState.DATA_LOAD
511518

512519
log.info("Restoring training from the last checkpoint at step {}.", step_nr)
@@ -574,6 +581,9 @@ def _do_run(self) -> None:
574581
with progress_task, self._lapse_watch:
575582
while self._state != _TrainerState.STOPPED:
576583
match self._state:
584+
case _TrainerState.PRE_VALIDATION:
585+
self._state = self._pre_validate()
586+
577587
case _TrainerState.DATA_LOAD:
578588
self._state = self._read_next_batches()
579589

@@ -607,6 +617,12 @@ def _do_run(self) -> None:
607617

608618
self._state = self._stop()
609619

620+
def _pre_validate(self) -> _TrainerState:
621+
if self._validator is not None:
622+
self._validate()
623+
624+
return _TrainerState.DATA_LOAD
625+
610626
def _read_next_batches(self) -> _TrainerState:
611627
with self._data_watch:
612628
try:
@@ -1045,6 +1061,9 @@ def _maybe_validate(self) -> float | None:
10451061
return score
10461062

10471063
def _should_validate(self) -> bool:
1064+
if self._validator is None:
1065+
return False
1066+
10481067
return self._should_do(
10491068
self._validate_after_n_steps,
10501069
self._validate_every_n_steps,
@@ -1053,7 +1072,13 @@ def _should_validate(self) -> bool:
10531072
)
10541073

10551074
def _validate(self) -> float | None:
1056-
log.info("Starting validation after step {}.", self._step_nr)
1075+
if self._validator is None:
1076+
raise InternalError("`_validator` is `None`.")
1077+
1078+
if self._step_nr == 0:
1079+
log.info("Starting pre-validation before training.")
1080+
else:
1081+
log.info("Starting validation after step {}.", self._step_nr)
10571082

10581083
self._model.module.eval()
10591084

@@ -1135,9 +1160,12 @@ def _should_do(
11351160
after_n_data_epochs: int,
11361161
every_n_data_epochs: int | None,
11371162
) -> bool:
1163+
if self._state == _TrainerState.PRE_VALIDATION:
1164+
return False
1165+
11381166
def should_do_at_step() -> bool:
11391167
if every_n_steps is not None:
1140-
if self._step_nr >= after_n_steps:
1168+
if self._step_nr > after_n_steps:
11411169
if self._step_nr % every_n_steps == 0:
11421170
return True
11431171

@@ -1158,7 +1186,7 @@ def should_do_at_step() -> bool:
11581186

11591187
if self._state == _TrainerState.END_OF_DATA_EPOCH:
11601188
if every_n_data_epochs is not None:
1161-
if self._data_epoch_nr >= after_n_data_epochs:
1189+
if self._data_epoch_nr > after_n_data_epochs:
11621190
if self._data_epoch_nr % every_n_data_epochs == 0:
11631191
already_done = should_do_at_step()
11641192

@@ -1191,16 +1219,17 @@ def step_nr(self) -> int:
11911219

11921220
class _TrainerState(Enum):
11931221
NOT_STARTED = 0
1194-
DATA_LOAD = 1
1195-
STEP = 2
1196-
POST_STEP = 3
1197-
END_OF_DATA_EPOCH = 4
1198-
END_OF_TRAINING = 5
1199-
END_OF_DATA = 6
1200-
GRADIENT_OVERFLOW = 7
1201-
EARLY_STOP = 8
1202-
STOP_REQUESTED = 9
1203-
STOPPED = 10
1222+
PRE_VALIDATION = 1
1223+
DATA_LOAD = 2
1224+
STEP = 3
1225+
POST_STEP = 4
1226+
END_OF_DATA_EPOCH = 5
1227+
END_OF_TRAINING = 6
1228+
END_OF_DATA = 7
1229+
GRADIENT_OVERFLOW = 8
1230+
EARLY_STOP = 9
1231+
STOP_REQUESTED = 10
1232+
STOPPED = 11
12041233

12051234

12061235
T = TypeVar("T")

src/fairseq2/recipes/_validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,8 @@ def reset(self) -> None:
538538
@final
539539
class NoopValidator(Validator):
540540
@override
541-
def run(self, train_step_nr: int, train_data_epoch_nr: int) -> float:
542-
return -torch.inf
541+
def run(self, train_step_nr: int, train_data_epoch_nr: int) -> float | None:
542+
return None
543543

544544
@override
545545
def reset(self) -> None:

src/fairseq2/recipes/common/_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def create_trainer(
119119
max_num_steps=regime_section.num_steps,
120120
max_num_data_epochs=regime_section.num_data_epochs,
121121
validator=validator,
122+
validate_at_start=regime_section.validate_at_start,
122123
validate_after_n_steps=regime_section.validate_after_n_steps,
123124
validate_every_n_steps=regime_section.validate_every_n_steps,
124125
validate_after_n_data_epochs=regime_section.validate_after_n_data_epochs,

src/fairseq2/recipes/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ class RegimeSection:
234234
num_data_epochs: int | None = None
235235
"""The maximum number of data epochs to train for."""
236236

237+
validate_at_start: bool = False
238+
"""If ``True``, runs validation before starting training."""
239+
237240
validate_after_n_steps: int = 0
238241
"""The number of steps after which to start validating the model."""
239242

src/fairseq2/recipes/wav2vec2/asr/_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class Wav2Vec2AsrTrainConfig:
105105
regime: RegimeSection = field(
106106
default_factory=lambda: RegimeSection(
107107
num_steps=20_000,
108-
validate_after_n_steps=10_000,
108+
validate_after_n_steps=9999,
109109
validate_every_n_steps=1_000,
110110
publish_metrics_every_n_steps=200,
111111
)

0 commit comments

Comments (0)