@@ -380,7 +380,6 @@ def __init__(
  def initialize(self) -> None:
    """Initializes the model. Must be called before any other method."""
    self._create_optimizer()
-    tf.summary.scalar('learning_rate', self._decayed_learning_rate)

  @property
  def use_deltas(self) -> bool:
@@ -1294,15 +1293,17 @@ def run_one_epoch():
      )
      return self.train_batch(schedule)

-    for epoch_index in range(0, num_epochs):
-      stats = run_one_epoch()
-      logging.info('Training: %s', stats)
-      if not hooks:
-        continue
-      for epochs_every, hook_function in hooks:
-        if (epoch_index + 1) % epochs_every == 0:
-          hook_function()
-    return stats
+    with timer.scoped('ModelBase.train - one batch', num_iterations=num_epochs):
+      for epoch_index in range(num_epochs):
+        tf.summary.experimental.set_step(epoch_index)
+        stats = run_one_epoch()
+        logging.info('Training: %s', stats)
+        if not hooks:
+          continue
+        for epochs_every, hook_function in hooks:
+          if (epoch_index + 1) % epochs_every == 0:
+            hook_function()
+    return stats

  def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation:
    output = self(schedule, train=True)
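
Note: `timer.scoped` is a project-internal helper whose definition is not part of this diff. A minimal sketch of a context manager with that shape, assuming it logs the total and per-iteration wall time of the enclosed block under the given label (the log format is an assumption):

    import contextlib
    import logging
    import time

    @contextlib.contextmanager
    def scoped(name: str, num_iterations: int = 1):
      # Hypothetical stand-in for the project's timer helper: measures the
      # wall time of the enclosed block and logs it, both in total and
      # averaged over `num_iterations`.
      start = time.monotonic()
      try:
        yield
      finally:
        elapsed = time.monotonic() - start
        logging.info(
            '%s: %.3fs total, %.6fs per iteration',
            name, elapsed, elapsed / max(num_iterations, 1))
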
@@ -1380,6 +1381,32 @@ def train_batch(

    grads = tape.gradient(loss_tensor, variables)
    grads_and_vars = zip(grads, variables)
+
+    # TODO(vbshah): Compute and log the number of steps per second as well.
+    tf.summary.scalar('learning_rate', self._decayed_learning_rate)
+    tf.summary.scalar('overall_loss', loss_tensor)
+
+    # TODO(vbshah): Consider writing delta loss summaries as well.
+    self._add_error_summaries('absolute_mse', loss.mean_squared_error)
+    self._add_error_summaries(
+        'relative_mae',
+        loss.mean_absolute_percentage_error,
+    )
+    self._add_error_summaries(
+        'relative_mse',
+        loss.mean_squared_percentage_error,
+    )
+    self._add_percentile_summaries(
+        'absolute_error',
+        self._collected_percentile_ranks,
+        loss.absolute_error_percentiles,
+    )
+    self._add_percentile_summaries(
+        'absolute_percentage_error',
+        self._collected_percentile_ranks,
+        loss.absolute_percentage_error_percentiles,
+    )
+
    stats['loss'] = loss_tensor
    stats['epoch'] = self.global_step
    stats['absolute_mse'] = loss.mean_squared_error
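
The summary calls added above write nothing unless a default summary writer is active, and they rely on the `tf.summary.experimental.set_step` call in the training loop so that the explicit `step=` argument can be omitted. A self-contained sketch of that interplay, with a made-up log directory and loss values:

    import tensorflow as tf

    writer = tf.summary.create_file_writer('/tmp/train_logs')  # hypothetical path
    with writer.as_default():
      for epoch_index in range(3):
        # Sets the default step used by summary ops that omit `step=`.
        tf.summary.experimental.set_step(epoch_index)
        tf.summary.scalar('overall_loss', 1.0 / (epoch_index + 1))
    writer.flush()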
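
The `_add_error_summaries` and `_add_percentile_summaries` helpers are defined elsewhere in the class and are not shown in this diff; the signatures here are inferred from the call sites. A plausible sketch, assuming per-task error tensors and one scalar summary per task (respectively per collected percentile rank):

    from collections.abc import Sequence

    import tensorflow as tf

    class ModelBase:  # elided; only the hypothetical helpers are sketched.

      def _add_error_summaries(self, name: str, errors: tf.Tensor) -> None:
        # Hypothetical: one scalar per task, e.g. 'absolute_mse/task_0'.
        for task_index, error in enumerate(tf.unstack(errors)):
          tf.summary.scalar(f'{name}/task_{task_index}', error)

      def _add_percentile_summaries(
          self,
          name: str,
          percentile_ranks: Sequence[int],
          percentiles: tf.Tensor,
      ) -> None:
        # Hypothetical: one scalar per percentile rank, e.g. 'absolute_error/p50'.
        for rank, value in zip(percentile_ranks, tf.unstack(percentiles)):
          tf.summary.scalar(f'{name}/p{rank}', value)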