Commit e5f5ddd

Fix metric logging (#72)
- Decouples the summary intervals, so one is not reliant on the other.
- Switches to a modulo check instead of remembering the last step number. This is more robust, especially if training is restarted at some step that is not a multiple of the set intervals (which previously would cause no data to be written); a sketch of the difference follows below.
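For illustration, here is a minimal sketch of the two gating strategies this commit trades between; `StatefulGate` and `modulo_gate` are hypothetical names for this sketch, not identifiers from the commit:

```python
class StatefulGate:
  """Old style: remember the step at which we last fired.

  The extra state must be persisted and restored correctly across
  restarts; if it is re-initialized while the global step is restored
  from a checkpoint, the cadence can silently break.
  """

  def __init__(self, interval):
    self._interval = interval
    self._last_step = 0

  def should_fire(self, step):
    if step - self._last_step >= self._interval:
      self._last_step = step
      return True
    return False


def modulo_gate(step, interval):
  """New style: stateless, fires on every exact multiple of interval."""
  return step % interval == 0
```

The stateless form needs nothing saved across restarts: resuming at any step, it simply fires again at the next multiple of the interval.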
1 parent 9372dc1 commit e5f5ddd

File tree: 2 files changed (+16, -17 lines)

compiler_opt/rl/trainer.py (9 additions, 14 deletions)
```diff
@@ -114,9 +114,6 @@ def __init__(
     self._checkpointer.initialize_or_restore()
 
     self._start_time = time.time()
-    self._last_checkpoint_step = 0
-    self._last_log_step = 0
-    self._summary_last_log_step = 0
 
   def _initialize_metrics(self):
     """Initializes metrics."""
@@ -126,8 +123,7 @@ def _initialize_metrics(self):
 
   def _update_metrics(self, experience, monitor_dict):
     """Updates metrics and exports to Tensorboard."""
-    if (self._global_step.numpy() >=
-        self._summary_last_log_step + self._summary_log_interval):
+    if tf.math.equal(self._global_step % self._summary_log_interval, 0):
       is_action = ~experience.is_boundary()
 
       self._data_action_mean.update_state(
@@ -136,6 +132,10 @@ def _update_metrics(self, experience, monitor_dict):
           experience.reward, sample_weight=is_action)
       self._num_trajectories.update_state(experience.is_first())
 
+    # Check earlier rather than later if we should record summaries.
+    # TF also checks it, but much later. Needed to avoid looping through
+    # the dict so gave the if a bigger scope
+    if tf.summary.should_record_summaries():
       with tf.name_scope('default/'):
         tf.summary.scalar(
             name='data_action_mean',
@@ -158,28 +158,23 @@ def _update_metrics(self, experience, monitor_dict):
       tf.summary.histogram(
           name='reward', data=experience.reward, step=self._global_step)
 
-      self._summary_last_log_step = self._global_step.numpy()
-
   def _reset_metrics(self):
     """Reset num_trajectories."""
     self._num_trajectories.reset_states()
 
   def _log_experiment(self, loss):
     """Log training info."""
-    global_step_val = self._global_step.numpy()
-    if global_step_val - self._last_log_step >= self._log_interval:
+    if tf.math.equal(self._global_step % self._log_interval, 0):
+      global_step_val = self._global_step.numpy()
       logging.info('step = %d, loss = %g', global_step_val, loss)
       time_acc = time.time() - self._start_time
-      steps_per_sec = (global_step_val - self._last_log_step) / time_acc
+      steps_per_sec = self._log_interval / time_acc
       logging.info('%.3f steps/sec', steps_per_sec)
-      self._last_log_step = global_step_val
       self._start_time = time.time()
 
   def _save_checkpoint(self):
-    if (self._global_step.numpy() - self._last_checkpoint_step >=
-        self._checkpoint_interval):
+    if tf.math.equal(self._global_step % self._checkpoint_interval, 0):
       self._checkpointer.save(global_step=self._global_step)
-      self._last_checkpoint_step = self._global_step.numpy()
 
   def global_step_numpy(self):
     return self._global_step.numpy()
```
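For context on the `tf.summary.should_record_summaries()` guard introduced above, here is a minimal eager-mode sketch of how TF2 summary gating composes with `tf.summary.record_if`; the writer path, the interval of 10, and the scalar value are illustrative assumptions, not values from this repo:

```python
import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/demo_summaries')
step = tf.Variable(0, dtype=tf.int64)

with writer.as_default():
  # record_if installs the condition that should_record_summaries() and
  # every tf.summary.* op consult before writing anything.
  with tf.summary.record_if(lambda: tf.math.equal(step % 10, 0)):
    # Checking once up front, as the trainer now does, skips the whole
    # export loop on off-interval steps instead of letting every summary
    # op evaluate the condition and no-op individually.
    if tf.summary.should_record_summaries():
      tf.summary.scalar('loss', 0.5, step=step)
```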

compiler_opt/rl/trainer_test.py (7 additions, 3 deletions)
```diff
@@ -92,20 +92,24 @@ def test_training(self):
         tf.compat.v1.train.AdamOptimizer(),
         num_outer_dims=2)
     test_trainer = trainer.Trainer(
-        root_dir=self.get_temp_dir(), agent=test_agent, summary_log_interval=1)
+        root_dir=self.get_temp_dir(),
+        agent=test_agent,
+        summary_log_interval=1,
+        summary_export_interval=10)
     self.assertEqual(0, test_trainer._global_step.numpy())
 
     dataset_iter = _create_test_data(batch_size=3, sequence_length=3)
     monitor_dict = {'default': {'test': 1}}
 
     with mock.patch.object(
         tf.summary, 'scalar', autospec=True) as mock_scalar_summary:
-      test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)
+      test_trainer.train(dataset_iter, monitor_dict, num_iterations=100)
       self.assertEqual(
           10,
           sum(1 for c in mock_scalar_summary.mock_calls
               if c[2]['name'] == 'test'))
-      self.assertEqual(10, test_trainer._global_step.numpy())
+      self.assertEqual(100, test_trainer._global_step.numpy())
+      self.assertEqual(100, test_trainer.global_step_numpy())
 
   def test_training_with_multiple_times(self):
     test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
```
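A quick check of the test's arithmetic: with `summary_log_interval=1`, `summary_export_interval=10`, and `num_iterations=100`, the `'test'` scalar from `monitor_dict` is presumably exported only on steps that are multiples of 10, so the mock should see exactly 100 / 10 = 10 calls, matching the assertion.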
