Skip to content

Commit 6113267

Browse files
committed
🚀 small update on base trainer and release v0.6.4
1 parent 75ce851 commit 6113267

File tree

2 files changed

+19
-24
lines changed

2 files changed

+19
-24
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
setuptools.setup(
3737
name="TensorFlowASR",
38-
version="0.6.3",
38+
version="0.6.4",
3939
author="Huy Le Nguyen",
4040
author_email="[email protected]",
4141
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/runners/base_runners.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,8 @@ def create_checkpoint_manager(self, max_to_keep=10, **kwargs):
149149
with self.strategy.scope():
150150
self.ckpt = tf.train.Checkpoint(steps=self.steps, **kwargs)
151151
checkpoint_dir = os.path.join(self.config.outdir, "checkpoints")
152-
if not os.path.exists(checkpoint_dir):
153-
os.makedirs(checkpoint_dir)
154-
self.ckpt_manager = tf.train.CheckpointManager(
155-
self.ckpt, checkpoint_dir, max_to_keep=max_to_keep)
152+
if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir)
153+
self.ckpt_manager = tf.train.CheckpointManager(self.ckpt, checkpoint_dir, max_to_keep=max_to_keep)
156154

157155
def save_checkpoint(self):
158156
"""Save checkpoint."""
@@ -191,9 +189,11 @@ def run(self):
191189
while not self._finished():
192190
self._train_epoch()
193191

194-
# save when training is done
192+
# save and evaluate when training is done
195193
self.save_checkpoint()
196194
self.save_model_weights()
195+
self.log_train_metrics()
196+
self._eval_epoch()
197197

198198
self.train_progbar.close()
199199
print("> Finish training")
@@ -221,8 +221,7 @@ def _train_epoch(self):
221221
self._check_save_interval()
222222

223223
# Print epoch info
224-
self.train_progbar.set_description_str(
225-
f"[Train] [Epoch {self.epochs}/{self.config.num_epochs}]")
224+
self.train_progbar.set_description_str(f"[Train] [Epoch {self.epochs}/{self.config.num_epochs}]")
226225

227226
# Print train info to progress bar
228227
self._print_train_metrics(self.train_progbar)
@@ -313,40 +312,36 @@ def fit(self, train_dataset, eval_dataset=None, train_bs=None, train_acs=None, e
313312

314313
# -------------------------------- LOGGING -------------------------------------
315314

315+
def log_train_metrics(self):
316+
self._write_to_tensorboard(self.train_metrics, self.steps, stage="train")
317+
"""Reset train metrics after save it to tensorboard."""
318+
for metric in self.train_metrics.keys():
319+
self.train_metrics[metric].reset_states()
320+
316321
def _check_log_interval(self):
317322
"""Save log interval."""
318-
if (self.steps % self.config.log_interval_steps == 0) or \
319-
(self.total_train_steps and self.steps >= self.total_train_steps):
320-
self._write_to_tensorboard(self.train_metrics, self.steps, stage="train")
321-
"""Reset train metrics after save it to tensorboard."""
322-
for metric in self.train_metrics.keys():
323-
self.train_metrics[metric].reset_states()
323+
if (self.steps.numpy() % self.config.log_interval_steps == 0):
324+
self.log_train_metrics()
324325

325326
def _check_save_interval(self):
326327
"""Save log interval."""
327-
if (self.steps % self.config.save_interval_steps == 0) or \
328-
(self.total_train_steps and self.steps >= self.total_train_steps):
328+
if (self.steps.numpy() % self.config.save_interval_steps == 0):
329329
self.save_checkpoint()
330330
self.save_model_weights()
331331

332332
def _check_eval_interval(self):
333333
"""Save log interval."""
334-
if (self.steps % self.config.eval_interval_steps == 0): # or \
335-
# (self.total_train_steps and self.steps >= self.total_train_steps):
334+
if (self.steps.numpy() % self.config.eval_interval_steps == 0):
336335
self._eval_epoch()
337336

338337
# -------------------------------- UTILS -------------------------------------
339338

340339
def _print_train_metrics(self, progbar):
341-
result_dict = {}
342-
for key, value in self.train_metrics.items():
343-
result_dict[f"{key}"] = str(value.result().numpy())
340+
result_dict = {key: str(value.result().numpy()) for key, value in self.train_metrics.items()}
344341
progbar.set_postfix(result_dict)
345342

346343
def _print_eval_metrics(self, progbar):
347-
result_dict = {}
348-
for key, value in self.eval_metrics.items():
349-
result_dict[f"{key}"] = str(value.result().numpy())
344+
result_dict = {key: str(value.result().numpy()) for key, value in self.eval_metrics.items()}
350345
progbar.set_postfix(result_dict)
351346

352347
# -------------------------------- END -------------------------------------

0 commit comments

Comments
 (0)