
Commit 43c3869

Merge pull request #1 from Raghava14/best_checkpoint
Added saving of best checkpoint
2 parents 60125c8 + 5754a2c commit 43c3869

File tree

2 files changed: +22 -5 lines changed

  onmt/models/model_saver.py (+7 -3)
  onmt/trainer.py (+15 -2)

onmt/models/model_saver.py

Lines changed: 7 additions & 3 deletions
@@ -37,7 +37,7 @@ def __init__(self, base_path, model, model_opt, fields, optim,
         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
 
-    def save(self, step, moving_average=None):
+    def save(self, step, is_best, moving_average=None):
         """Main entry point for model saver
 
         It wraps the `_save` method with checks and apply `keep_checkpoint`
@@ -54,7 +54,7 @@ def save(self, step, moving_average=None):
                 model_params_data.append(param.data)
                 param.data = avg.data
 
-        chkpt, chkpt_name = self._save(step, save_model)
+        chkpt, chkpt_name = self._save(step, is_best, save_model)
         self.last_saved_step = step
 
         if moving_average:
@@ -97,7 +97,7 @@ def _rm_checkpoint(self, name):
 class ModelSaver(ModelSaverBase):
     """Simple model saver to filesystem"""
 
-    def _save(self, step, model):
+    def _save(self, step, is_best, model):
         model_state_dict = model.state_dict()
         model_state_dict = {k: v for k, v in model_state_dict.items()
                             if 'generator' not in k}
@@ -128,6 +128,10 @@ def _save(self, step, model):
         logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
         checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
         torch.save(checkpoint, checkpoint_path)
+        if is_best:
+            logger.info("Obtained best checkpoint, saving...")
+            best_checkpoint_path = '%s_best.pt' % (self.base_path)
+            torch.save(checkpoint, best_checkpoint_path)
         return checkpoint, checkpoint_path
 
     def _rm_checkpoint(self, name):
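In short, `ModelSaver._save` still writes every checkpoint to its step-numbered path, and when the caller flags the checkpoint as the best seen so far it also writes the same object to a fixed `<base_path>_best.pt` file. Below is a minimal standalone sketch of that pattern; the `save_checkpoint` helper and its signature are hypothetical stand-ins for illustration, not part of OpenNMT-py.

import logging

import torch

logger = logging.getLogger(__name__)


def save_checkpoint(checkpoint, base_path, step, is_best):
    """Write a step-numbered checkpoint; mirror it to '<base_path>_best.pt' if best."""
    checkpoint_path = '%s_step_%d.pt' % (base_path, step)
    torch.save(checkpoint, checkpoint_path)
    if is_best:
        # The best checkpoint lives at a fixed path and is overwritten
        # whenever a new best arrives.
        logger.info("Obtained best checkpoint, saving...")
        torch.save(checkpoint, '%s_best.pt' % base_path)
    return checkpoint_path

Keeping the best model at one fixed path lets downstream tooling always load `<base_path>_best.pt` without scanning step-numbered files, at the cost of one extra `torch.save` whenever a new best appears.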

onmt/trainer.py

Lines changed: 15 additions & 2 deletions
@@ -267,6 +267,17 @@ def train(self,
             if self.average_decay > 0 and i % self.average_every == 0:
                 self._update_average(step)
 
+            perplexity = report_stats.ppl()
+            if i == 0:
+                best_perplexity = perplexity
+                is_best = True
+            else:
+                if perplexity < best_perplexity:
+                    best_perplexity = perplexity
+                    is_best = True
+                else:
+                    is_best = False
+
             report_stats = self._maybe_report_training(
                 step, train_steps,
                 self.optim.learning_rate(),
@@ -297,13 +308,15 @@ def train(self,
             if (self.model_saver is not None
                 and (save_checkpoint_steps != 0
                      and step % save_checkpoint_steps == 0)):
-                self.model_saver.save(step, moving_average=self.moving_average)
+                self.model_saver.save(step, is_best,
+                                      moving_average=self.moving_average)
 
             if train_steps > 0 and step >= train_steps:
                 break
 
         if self.model_saver is not None:
-            self.model_saver.save(step, moving_average=self.moving_average)
+            self.model_saver.save(step, is_best,
+                                  moving_average=self.moving_average)
         return total_stats
 
     def validate(self, valid_iter, moving_average=None):
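On the trainer side, `is_best` is derived from the training perplexity: the first accumulation step is always treated as the best so far, and after that only a strictly lower perplexity resets the flag. The sketch below isolates just that bookkeeping and is runnable on its own; `get_perplexity` and `save_checkpoint` are hypothetical stand-ins for `report_stats.ppl()` and `model_saver.save()`.

import math
import random


def get_perplexity(step):
    # Hypothetical stand-in for report_stats.ppl(): a noisily decaying value.
    return math.exp(2.0 / step) + 0.1 * random.random()


def save_checkpoint(step, is_best):
    # Hypothetical stand-in for model_saver.save(step, is_best, ...).
    suffix = " (best so far)" if is_best else ""
    print("step %d: saving checkpoint%s" % (step, suffix))


save_checkpoint_steps = 5

for i, step in enumerate(range(1, 21)):
    perplexity = get_perplexity(step)
    if i == 0:
        # The first iteration always counts as the best so far.
        best_perplexity = perplexity
        is_best = True
    elif perplexity < best_perplexity:
        best_perplexity = perplexity
        is_best = True
    else:
        is_best = False

    if save_checkpoint_steps != 0 and step % save_checkpoint_steps == 0:
        save_checkpoint(step, is_best)

As in the diff, `is_best` reflects the perplexity of the most recent iteration at the moment a checkpoint is written, so a checkpoint can be saved as non-best even if an earlier iteration inside the same save interval set a new best.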
