@@ -149,10 +149,8 @@ def create_checkpoint_manager(self, max_to_keep=10, **kwargs):
149149 with self .strategy .scope ():
150150 self .ckpt = tf .train .Checkpoint (steps = self .steps , ** kwargs )
151151 checkpoint_dir = os .path .join (self .config .outdir , "checkpoints" )
152- if not os .path .exists (checkpoint_dir ):
153- os .makedirs (checkpoint_dir )
154- self .ckpt_manager = tf .train .CheckpointManager (
155- self .ckpt , checkpoint_dir , max_to_keep = max_to_keep )
152+ if not os .path .exists (checkpoint_dir ): os .makedirs (checkpoint_dir )
153+ self .ckpt_manager = tf .train .CheckpointManager (self .ckpt , checkpoint_dir , max_to_keep = max_to_keep )
156154
157155 def save_checkpoint (self ):
158156 """Save checkpoint."""
@@ -191,9 +189,11 @@ def run(self):
191189 while not self ._finished ():
192190 self ._train_epoch ()
193191
194- # save when training is done
192+ # save and evaluate when training is done
195193 self .save_checkpoint ()
196194 self .save_model_weights ()
195+ self .log_train_metrics ()
196+ self ._eval_epoch ()
197197
198198 self .train_progbar .close ()
199199 print ("> Finish training" )
@@ -221,8 +221,7 @@ def _train_epoch(self):
221221 self ._check_save_interval ()
222222
223223 # Print epoch info
224- self .train_progbar .set_description_str (
225- f"[Train] [Epoch { self .epochs } /{ self .config .num_epochs } ]" )
224+ self .train_progbar .set_description_str (f"[Train] [Epoch { self .epochs } /{ self .config .num_epochs } ]" )
226225
227226 # Print train info to progress bar
228227 self ._print_train_metrics (self .train_progbar )
@@ -313,40 +312,36 @@ def fit(self, train_dataset, eval_dataset=None, train_bs=None, train_acs=None, e
313312
314313 # -------------------------------- LOGGING -------------------------------------
315314
315+ def log_train_metrics (self ):
316+ self ._write_to_tensorboard (self .train_metrics , self .steps , stage = "train" )
317+ """Reset train metrics after saving them to TensorBoard."""
318+ for metric in self .train_metrics .keys ():
319+ self .train_metrics [metric ].reset_states ()
320+
316321 def _check_log_interval (self ):
317322 """Log train metrics when the step count reaches a log interval."""
318- if (self .steps % self .config .log_interval_steps == 0 ) or \
319- (self .total_train_steps and self .steps >= self .total_train_steps ):
320- self ._write_to_tensorboard (self .train_metrics , self .steps , stage = "train" )
321- """Reset train metrics after save it to tensorboard."""
322- for metric in self .train_metrics .keys ():
323- self .train_metrics [metric ].reset_states ()
323+ if (self .steps .numpy () % self .config .log_interval_steps == 0 ):
324+ self .log_train_metrics ()
324325
325326 def _check_save_interval (self ):
326327 """Save a checkpoint and the model weights when the step count reaches a save interval."""
327- if (self .steps % self .config .save_interval_steps == 0 ) or \
328- (self .total_train_steps and self .steps >= self .total_train_steps ):
328+ if (self .steps .numpy () % self .config .save_interval_steps == 0 ):
329329 self .save_checkpoint ()
330330 self .save_model_weights ()
331331
332332 def _check_eval_interval (self ):
333333 """Run an evaluation epoch when the step count reaches an eval interval."""
334- if (self .steps % self .config .eval_interval_steps == 0 ): # or \
335- # (self.total_train_steps and self.steps >= self.total_train_steps):
334+ if (self .steps .numpy () % self .config .eval_interval_steps == 0 ):
336335 self ._eval_epoch ()
337336
338337 # -------------------------------- UTILS -------------------------------------
339338
340339 def _print_train_metrics (self , progbar ):
341- result_dict = {}
342- for key , value in self .train_metrics .items ():
343- result_dict [f"{ key } " ] = str (value .result ().numpy ())
340+ result_dict = {key : str (value .result ().numpy ()) for key , value in self .train_metrics .items ()}
344341 progbar .set_postfix (result_dict )
345342
346343 def _print_eval_metrics (self , progbar ):
347- result_dict = {}
348- for key , value in self .eval_metrics .items ():
349- result_dict [f"{ key } " ] = str (value .result ().numpy ())
344+ result_dict = {key : str (value .result ().numpy ()) for key , value in self .eval_metrics .items ()}
350345 progbar .set_postfix (result_dict )
351346
352347 # -------------------------------- END -------------------------------------
0 commit comments