
Commit 0d094dc

Restore TensorBoard summary logging after TF 2 migration. (#326)
* Logs most of the previously logged scalars.
* Bonus: wraps training in `timer.scoped`.
1 parent bce4600 commit 0d094dc

File tree

2 files changed: +39 / -12 lines


gematria/model/python/main_function.py

Lines changed: 2 additions & 2 deletions
@@ -826,8 +826,8 @@ def checkpoint_model():
   )

   with train_summary_writer.as_default(), tf.summary.record_if(
-      lambda: tf.math.equal(
-          model.global_step % _GEMATRIA_SAVE_SUMMARIES_EPOCHS, 0
+      lambda: tf.equal(
+          model.global_step % _GEMATRIA_SAVE_SUMMARIES_EPOCHS.value, 0
       )
   ):
     model.train(
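
The substantive fix in this hunk is the added `.value`: `_GEMATRIA_SAVE_SUMMARIES_EPOCHS` is evidently an absl flag object, so the modulo must use the flag's value rather than the flag itself (the `tf.math.equal` to `tf.equal` change is just an alias rename). A minimal, self-contained sketch of the same gating pattern, assuming a hypothetical log directory and interval in place of the Gematria flag:

import tensorflow as tf

save_every_epochs = 10  # Stand-in for _GEMATRIA_SAVE_SUMMARIES_EPOCHS.value.
global_step = tf.Variable(0, dtype=tf.int64)
writer = tf.summary.create_file_writer('/tmp/summaries')  # Hypothetical path.

with writer.as_default(), tf.summary.record_if(
    lambda: tf.equal(global_step % save_every_epochs, 0)
):
  for _ in range(100):
    # The scalar is only written on steps where the predicate above is True.
    tf.summary.scalar('loss', 0.5, step=global_step)
    global_step.assign_add(1)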

gematria/model/python/model_base.py

Lines changed: 37 additions & 10 deletions
@@ -380,7 +380,6 @@ def __init__(
   def initialize(self) -> None:
     """Initializes the model. Must be called before any other method."""
     self._create_optimizer()
-    tf.summary.scalar('learning_rate', self._decayed_learning_rate)

   @property
   def use_deltas(self) -> bool:
@@ -1294,15 +1293,17 @@ def run_one_epoch():
       )
       return self.train_batch(schedule)

-    for epoch_index in range(0, num_epochs):
-      stats = run_one_epoch()
-      logging.info('Training: %s', stats)
-      if not hooks:
-        continue
-      for epochs_every, hook_function in hooks:
-        if (epoch_index + 1) % epochs_every == 0:
-          hook_function()
-    return stats
+    with timer.scoped('ModelBase.train - one batch', num_iterations=num_epochs):
+      for epoch_index in range(num_epochs):
+        tf.summary.experimental.set_step(epoch_index)
+        stats = run_one_epoch()
+        logging.info('Training: %s', stats)
+        if not hooks:
+          continue
+        for epochs_every, hook_function in hooks:
+          if (epoch_index + 1) % epochs_every == 0:
+            hook_function()
+      return stats

   def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation:
     output = self(schedule, train=True)
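
Two things change in the loop: each epoch now calls `tf.summary.experimental.set_step(epoch_index)`, which sets the default summary step so the `tf.summary.scalar` calls added in `train_batch` below don't need an explicit `step=` argument, and the whole loop runs under Gematria's `timer.scoped` helper. A rough sketch of the `set_step` mechanics, with a no-op context manager standing in for `timer.scoped` (its import isn't visible in this diff) and an invented loss value:

import contextlib
import tensorflow as tf

@contextlib.contextmanager
def scoped_timer(name, num_iterations):
  # Stand-in for Gematria's timer.scoped; the real helper reports timings.
  yield

writer = tf.summary.create_file_writer('/tmp/summaries')  # Hypothetical path.
with writer.as_default():
  with scoped_timer('train', num_iterations=5):
    for epoch_index in range(5):
      # Sets the default step for all summary ops that follow, so callees
      # can write scalars without threading the epoch index through.
      tf.summary.experimental.set_step(epoch_index)
      tf.summary.scalar('overall_loss', 1.0 / (epoch_index + 1))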
@@ -1380,6 +1381,32 @@ def train_batch(

     grads = tape.gradient(loss_tensor, variables)
     grads_and_vars = zip(grads, variables)
+
+    # TODO(vbshah): Compute and log the number of steps per second as well.
+    tf.summary.scalar('learning_rate', self._decayed_learning_rate)
+    tf.summary.scalar('overall_loss', loss_tensor)
+
+    # TODO(vbshah): Consider writing delta loss summaries as well.
+    self._add_error_summaries('absolute_mse', loss.mean_squared_error)
+    self._add_error_summaries(
+        'relative_mae',
+        loss.mean_absolute_percentage_error,
+    )
+    self._add_error_summaries(
+        'relative_mse',
+        loss.mean_squared_percentage_error,
+    )
+    self._add_percentile_summaries(
+        'absolute_error',
+        self._collected_percentile_ranks,
+        loss.absolute_error_percentiles,
+    )
+    self._add_percentile_summaries(
+        'absolute_percentage_error',
+        self._collected_percentile_ranks,
+        loss.absolute_percentage_error_percentiles,
+    )
+
     stats['loss'] = loss_tensor
     stats['epoch'] = self.global_step
     stats['absolute_mse'] = loss.mean_squared_error
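
`_add_error_summaries` and `_add_percentile_summaries` are not shown in this commit; judging from the call sites, they fan a per-task error tensor and a percentile tensor out into individual `tf.summary.scalar` writes. A hypothetical sketch of helpers with that shape (names, signatures, and tag layout are guesses, not the Gematria implementation):

from collections.abc import Sequence

import tensorflow as tf

def add_error_summaries(name: str, errors: tf.Tensor,
                        task_names: Sequence[str]) -> None:
  # One scalar per task, e.g. 'absolute_mse/cycles'. Relies on the default
  # step set via tf.summary.experimental.set_step() in the training loop.
  for task_index, task_name in enumerate(task_names):
    tf.summary.scalar(f'{name}/{task_name}', errors[task_index])

def add_percentile_summaries(name: str, percentile_ranks: Sequence[int],
                             percentiles: tf.Tensor) -> None:
  # One scalar per collected percentile rank, e.g. 'absolute_error/p99'.
  for rank_index, rank in enumerate(percentile_ranks):
    tf.summary.scalar(f'{name}/p{rank}', percentiles[rank_index])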
