
Commit ac094ca

Make the tensorboard data interval adjustable (google#62)
Currently, all of the tensorboard variables get updated every single iteration, which is quite costly in terms of performance. I'm seeing about a 25% performance improvement by running the update once every 100 iterations instead. The data showing up in tensorboard will be sparser, but it should still be good enough to see the overall trends that we're looking for.
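The change gates metric updates on a step counter instead of writing summaries every iteration. Below is a minimal standalone sketch of the same gating pattern, assuming only TensorFlow; the names `maybe_log_summaries`, the writer path, and the loss metric are illustrative, not the repo's code:

```python
import tensorflow as tf

# Illustrative settings; the commit uses 100 steps between metric updates.
summary_log_interval = 100
last_log_step = 0
global_step = tf.Variable(0, dtype=tf.int64)
writer = tf.summary.create_file_writer('/tmp/example_logs')

def maybe_log_summaries(loss):
  """Writes scalar summaries at most once per `summary_log_interval` steps."""
  global last_log_step
  step = int(global_step.numpy())
  if step >= last_log_step + summary_log_interval:
    with writer.as_default():
      tf.summary.scalar(name='loss', data=loss, step=step)
    last_log_step = step
```

Every call outside the interval is a cheap integer comparison, which is where the reported speedup comes from.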
1 parent 21c230c commit ac094ca

File tree: 2 files changed, +45 −36 lines changed


compiler_opt/rl/trainer.py

Lines changed: 44 additions & 35 deletions
@@ -50,7 +50,8 @@ def __init__(
       # Params for summaries and logging
       checkpoint_interval=10000,
       log_interval=100,
-      summary_interval=1000,
+      summary_log_interval=100,
+      summary_export_interval=1000,
       summaries_flush_secs=10):
     """Initialize the Trainer object.
 
@@ -62,16 +63,19 @@ def __init__(
       checkpoint_interval: int, the training step interval for saving
         checkpoint.
       log_interval: int, the training step interval for logging.
-      summary_interval: int, the training step interval for exporting to
-        tensorboard.
+      summary_log_interval: the number of steps in between logging metrics
+        to tensorboard.
+      summary_export_interval: int, the training step interval for exporting
+        to tensorboard.
       summaries_flush_secs: int, the seconds for flushing to tensorboard.
     """
     self._root_dir = root_dir
     self._agent = agent
     self._random_network_distillation = random_network_distillation
     self._checkpoint_interval = checkpoint_interval
     self._log_interval = log_interval
-    self._summary_interval = summary_interval
+    self._summary_log_interval = summary_log_interval
+    self._summary_export_interval = summary_export_interval
 
     self._summary_writer = tf.summary.create_file_writer(
         self._root_dir, flush_millis=summaries_flush_secs * 1000)
@@ -108,6 +112,7 @@ def __init__(
     self._start_time = time.time()
     self._last_checkpoint_step = 0
     self._last_log_step = 0
+    self._summary_last_log_step = 0
 
   def _initialize_metrics(self):
     """Initializes metrics."""
@@ -117,35 +122,39 @@ def _initialize_metrics(self):
 
   def _update_metrics(self, experience, monitor_dict):
     """Updates metrics and exports to Tensorboard."""
-    is_action = ~experience.is_boundary()
-
-    self._data_action_mean.update_state(
-        experience.action, sample_weight=is_action)
-    self._data_reward_mean.update_state(
-        experience.reward, sample_weight=is_action)
-    self._num_trajectories.update_state(experience.is_first())
-
-    with tf.name_scope('default/'):
-      tf.summary.scalar(
-          name='data_action_mean',
-          data=self._data_action_mean.result(),
-          step=self._global_step)
-      tf.summary.scalar(
-          name='data_reward_mean',
-          data=self._data_reward_mean.result(),
-          step=self._global_step)
-      tf.summary.scalar(
-          name='num_trajectories',
-          data=self._num_trajectories.result(),
-          step=self._global_step)
-
-    for name_scope, d in monitor_dict.items():
-      with tf.name_scope(name_scope + '/'):
-        for key, value in d.items():
-          tf.summary.scalar(name=key, data=value, step=self._global_step)
-
-    tf.summary.histogram(
-        name='reward', data=experience.reward, step=self._global_step)
+    if (self._global_step.numpy() >=
+        self._summary_last_log_step + self._summary_log_interval):
+      is_action = ~experience.is_boundary()
+
+      self._data_action_mean.update_state(
+          experience.action, sample_weight=is_action)
+      self._data_reward_mean.update_state(
+          experience.reward, sample_weight=is_action)
+      self._num_trajectories.update_state(experience.is_first())
+
+      with tf.name_scope('default/'):
+        tf.summary.scalar(
+            name='data_action_mean',
+            data=self._data_action_mean.result(),
+            step=self._global_step)
+        tf.summary.scalar(
+            name='data_reward_mean',
+            data=self._data_reward_mean.result(),
+            step=self._global_step)
+        tf.summary.scalar(
+            name='num_trajectories',
+            data=self._num_trajectories.result(),
+            step=self._global_step)
+
+      for name_scope, d in monitor_dict.items():
+        with tf.name_scope(name_scope + '/'):
+          for key, value in d.items():
+            tf.summary.scalar(name=key, data=value, step=self._global_step)
+
+      tf.summary.histogram(
+          name='reward', data=experience.reward, step=self._global_step)
+
+      self._summary_last_log_step = self._global_step.numpy()
 
   def _reset_metrics(self):
     """Reset num_trajectories."""
@@ -176,8 +185,8 @@ def train(self, dataset_iter, monitor_dict, num_iterations):
     self._reset_metrics()
     # context management is implemented in decorator
     # pylint: disable=not-context-manager
-    with tf.summary.record_if(
-        lambda: tf.math.equal(self._global_step % self._summary_interval, 0)):
+    with tf.summary.record_if(lambda: tf.math.equal(
+        self._global_step % self._summary_export_interval, 0)):
       for _ in range(num_iterations):
         # When the data is not enough to fill in a batch, next(dataset_iter)
         # will throw StopIteration exception, logging a warning message instead
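The two knobs compose: `_update_metrics` only runs its body every `summary_log_interval` steps, while `tf.summary.record_if` only lets the summary ops inside `train` actually write every `summary_export_interval` steps. A hypothetical construction sketch based on the signature shown above; `my_agent` is a placeholder for a TF-Agents agent, and the remaining arguments keep their defaults:

```python
from compiler_opt.rl import trainer

# `my_agent` is a placeholder; trainer_test.py below shows how a real
# TF-Agents agent is constructed for testing.
test_trainer = trainer.Trainer(
    root_dir='/tmp/train_logs',        # where summaries are written
    agent=my_agent,
    summary_log_interval=100,          # steps between metric updates
    summary_export_interval=1000)      # steps between tensorboard exports
```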

compiler_opt/rl/trainer_test.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def test_training(self):
         tf.compat.v1.train.AdamOptimizer(),
         num_outer_dims=2)
     test_trainer = trainer.Trainer(
-        root_dir=self.get_temp_dir(), agent=test_agent)
+        root_dir=self.get_temp_dir(), agent=test_agent, summary_log_interval=1)
     self.assertEqual(0, test_trainer._global_step.numpy())
 
     dataset_iter = _create_test_data(batch_size=3, sequence_length=3)
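Presumably the test pins `summary_log_interval=1` so that `_update_metrics` still fires during the test's short training run; with the default of 100, the handful of steps the test executes would never pass the new gate.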
