Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions deeppavlov/core/trainers/nn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class NNTrainer(FitTrainer):
log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``)
max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``,
ignored if negative (default is ``-1``)
always_save_model: if True, we always save the obtained weights of our model, regardless of the metric.
(default if ``False``)
**kwargs: additional parameters whose names will be logged but otherwise ignored


Expand Down Expand Up @@ -107,6 +109,7 @@ def __init__(self, chainer_config: dict, *,
validate_first: bool = True,
validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1,
log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1,
always_save_model: bool = False,
**kwargs) -> None:
super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets,
show_examples=show_examples, max_test_batches=max_test_batches, **kwargs)
Expand Down Expand Up @@ -141,6 +144,7 @@ def _improved(op):
self.max_epochs = epochs
self.epoch = start_epoch_num
self.max_batches = max_batches
self.always_save_model = always_save_model

self.train_batches_seen = 0
self.examples = 0
Expand Down Expand Up @@ -207,6 +211,11 @@ def _validate(self, iterator: DataLearningIterator,
self.score_best = score
log.info('Saving model')
self.save()
elif self.always_save_model:
log.info(f'Changed {m_name} from {self.score_best} to {score}')
self.score_best = score
log.info('But due to always_save_model, saving the model')
self.save()
else:
log.info('Did not improve on the {} of {}'.format(m_name, self.score_best))

Expand Down