-
Notifications
You must be signed in to change notification settings - Fork 191
Stop criteria #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Stop criteria #79
Changes from 4 commits
a23650c
84bc0f8
e2b4b3b
4bb83d8
42acc60
114adfd
d8b56b0
72f4f7f
d4d9630
9e22191
9fc5caf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -212,11 +212,12 @@ def train(self): | |
|
|
||
| # This is the training loop. | ||
| step_time, train_loss = 0.0, 0.0 | ||
| current_step, num_iter_wo_improve = 0, 0 | ||
| prev_train_losses, prev_valid_losses = [], [] | ||
| num_iter_cover_train = int(sum(train_bucket_sizes) / | ||
| self.params.batch_size / | ||
| self.params.steps_per_checkpoint) | ||
| current_step, iter_inx, num_epochs_last_impr, max_num_epochs,\ | ||
|
||
| num_up_trends, num_down_trends = 0, 0, 0, 2, 0, 0 | ||
|
||
| prev_train_losses, prev_valid_losses, prev_epoch_valid_losses = [], [], [] | ||
| num_iter_cover_train = max(1, int(sum(train_bucket_sizes) / | ||
|
||
| self.params.batch_size / | ||
| self.params.steps_per_checkpoint)) | ||
| while (self.params.max_steps == 0 | ||
| or self.model.global_step.eval(self.session) | ||
| <= self.params.max_steps): | ||
|
|
@@ -232,45 +233,70 @@ def train(self): | |
| # Print statistics for the previous steps. | ||
| train_ppx = math.exp(train_loss) if train_loss < 300 else float('inf') | ||
| print ("global step %d learning rate %.4f step-time %.2f perplexity " | ||
| "%.2f" % (self.model.global_step.eval(self.session), | ||
| "%.3f" % (self.model.global_step.eval(self.session), | ||
| self.model.learning_rate.eval(self.session), | ||
| step_time, train_ppx)) | ||
| eval_loss = self.__calc_eval_loss() | ||
| eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf') | ||
| print(" eval: perplexity %.2f" % (eval_ppx)) | ||
| print(" eval: perplexity %.3f" % (eval_ppx)) | ||
| # Decrease learning rate if no improvement was seen on train set | ||
| # over last 3 times. | ||
| if (len(prev_train_losses) > 2 | ||
| and train_loss > max(prev_train_losses[-3:])): | ||
| self.session.run(self.model.learning_rate_decay_op) | ||
|
|
||
| if (len(prev_valid_losses) > 0 | ||
| and eval_loss <= min(prev_valid_losses)): | ||
| # Save checkpoint and zero timer and loss. | ||
| self.model.saver.save(self.session, | ||
| os.path.join(self.model_dir, "model"), | ||
| write_meta_graph=False) | ||
|
|
||
| if (len(prev_valid_losses) > 0 | ||
| and eval_loss >= min(prev_valid_losses)): | ||
| num_iter_wo_improve += 1 | ||
| else: | ||
| num_iter_wo_improve = 0 | ||
|
|
||
| if num_iter_wo_improve > num_iter_cover_train * 2: | ||
| print("No improvement over last %d times. Training will stop after %d" | ||
| "iterations if no improvement was seen." | ||
| % (num_iter_wo_improve, | ||
| num_iter_cover_train - num_iter_wo_improve)) | ||
|
|
||
| # Stop train if no improvement was seen on validation set | ||
| # over last 3 epochs. | ||
| if num_iter_wo_improve > num_iter_cover_train * 3: | ||
| break | ||
| #if (len(prev_valid_losses) > 0 | ||
| # and eval_loss <= min(prev_valid_losses)): | ||
| # Save checkpoint and zero timer and loss. | ||
| self.model.saver.save(self.session, | ||
| os.path.join(self.model_dir, "model"), | ||
| write_meta_graph=False) | ||
|
|
||
| # After epoch pass, calculate average epoch loss | ||
| # and then make a decision to continue/stop training. | ||
| if (iter_inx > 0 | ||
| and iter_inx % num_iter_cover_train == 0): | ||
| # Calculate average validation loss during the previous epoch | ||
| epoch_eval_loss = self.__calc_epoch_loss( | ||
| prev_valid_losses[-num_iter_cover_train:]) | ||
| if len(prev_epoch_valid_losses) > 0: | ||
| print('Previous min epoch eval loss: %f, current epoch eval loss: %f' % | ||
| (min(prev_epoch_valid_losses), epoch_eval_loss)) | ||
| # Check if there was improvement during last epoch | ||
|
||
| if (epoch_eval_loss < min(prev_epoch_valid_losses)): | ||
| if num_epochs_last_impr > max_num_epochs/1.5: | ||
| max_num_epochs = int(1.5 * num_epochs_last_impr) | ||
|
||
| print('Improved during last epoch.') | ||
| prev_min_level = prev_epoch_valid_losses[-1] | ||
| num_epochs_last_impr, num_up_trends, num_down_trends = 0, 0, 0 | ||
| else: | ||
| print('No improvement during last epoch.') | ||
| num_epochs_last_impr += 1 | ||
|
||
| if (prev_epoch_valid_losses[-1] < epoch_eval_loss | ||
| and num_up_trends <= num_down_trends): | ||
| num_up_trends += 1 | ||
| elif (epoch_eval_loss < prev_epoch_valid_losses[-1] | ||
| and num_down_trends <= num_up_trends): | ||
| num_down_trends += 1 | ||
|
|
||
| print('Num up trends: %d, num down trends: %d' % | ||
| (num_up_trends, num_down_trends)) | ||
|
|
||
| print('Number of the epochs passed from the last improvement: %d' | ||
| % num_epochs_last_impr) | ||
| print('Max allowable number of epochs for improvement: %d' | ||
| % max_num_epochs) | ||
|
|
||
| if (num_epochs_last_impr > max_num_epochs | ||
| and num_up_trends > 1): | ||
| break | ||
|
|
||
| prev_epoch_valid_losses.append(round(epoch_eval_loss, 3)) | ||
|
|
||
| prev_train_losses.append(train_loss) | ||
| prev_valid_losses.append(eval_loss) | ||
| step_time, train_loss = 0.0, 0.0 | ||
| iter_inx += 1 | ||
|
|
||
| print('Training done.') | ||
| with tf.Graph().as_default(): | ||
|
|
@@ -316,6 +342,26 @@ def __calc_eval_loss(self): | |
| return eval_loss | ||
|
|
||
|
|
||
def __calc_epoch_loss(self, epoch_losses):
    """Calculate the average loss during the epoch, ignoring outliers.

    Only losses strictly below 1.5 times the smallest loss of the epoch
    contribute to the average; larger values are left out.

    Args:
      epoch_losses: list of the losses during the epoch;

    Returns:
      average value of the retained losses during the period, or
      float('inf') when no loss is retained (empty input, or a minimum
      loss that is zero or negative);
    """
    if not epoch_losses:
        return float('inf')
    # The cut-off is the same for every element, so compute it once
    # instead of calling min() on each iteration.
    threshold = min(epoch_losses) * 1.5
    kept_losses = [loss for loss in epoch_losses if loss < threshold]
    if not kept_losses:
        # The string form is required: a bare `inf` is not a defined name.
        return float('inf')
    return sum(kept_losses) / len(kept_losses)
|
|
||
|
|
||
| def decode_word(self, word): | ||
| """Decode input word to sequence of phonemes. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should work by default, not with options.