Skip to content

Commit 513d50d

Browse files
Add a percentage correct metric for BC training (#402)
This patch adds a percentage-correct (accuracy) metric for BC training, which makes it much easier to interpret how a model is doing than just staring at loss values.
1 parent fffde33 commit 513d50d

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

compiler_opt/rl/trainer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from compiler_opt.rl import random_net_distillation
2424
from tf_agents.agents import tf_agent
2525
from tf_agents.policies import policy_loader
26+
from tf_agents import trajectories
2627

2728
from tf_agents.utils import common as common_utils
2829
from typing import Optional
@@ -54,7 +55,8 @@ def __init__(
5455
log_interval=100,
5556
summary_log_interval=100,
5657
summary_export_interval=1000,
57-
summaries_flush_secs=10):
58+
summaries_flush_secs=10,
59+
bc_percentage_correct=False):
5860
"""Initialize the Trainer object.
5961
6062
Args:
@@ -70,6 +72,9 @@ def __init__(
7072
summary_export_interval: int, the training step interval for exporting
7173
to tensorboard.
7274
summaries_flush_secs: int, the seconds for flushing to tensorboard.
75+
bc_percentage_correct: bool, whether or not to log the accuracy of the
76+
current batch. This is intended for use during BC training where labels
77+
for the "correct" decision are available.
7378
"""
7479
self._root_dir = root_dir
7580
self._agent = agent
@@ -84,6 +89,7 @@ def __init__(
8489
self._summary_writer.set_as_default()
8590

8691
self._global_step = tf.compat.v1.train.get_or_create_global_step()
92+
self._bc_percentage_correct = bc_percentage_correct
8793

8894
# Initialize agent and trajectory replay.
8995
# Wrap training and trajectory replay in a tf.function to make it much
@@ -118,6 +124,7 @@ def _initialize_metrics(self):
118124
self._data_action_mean = tf.keras.metrics.Mean()
119125
self._data_reward_mean = tf.keras.metrics.Mean()
120126
self._num_trajectories = tf.keras.metrics.Sum()
127+
self._percentage_correct = tf.keras.metrics.Accuracy()
121128

122129
def _update_metrics(self, experience, monitor_dict):
123130
"""Updates metrics and exports to Tensorboard."""
@@ -130,6 +137,16 @@ def _update_metrics(self, experience, monitor_dict):
130137
experience.reward, sample_weight=is_action)
131138
self._num_trajectories.update_state(experience.is_first())
132139

140+
# Compute the accuracy if we are BC training.
141+
if self._bc_percentage_correct:
142+
experience_time_step = trajectories.TimeStep(experience.step_type,
143+
experience.reward,
144+
experience.discount,
145+
experience.observation)
146+
policy_actions = self._agent.policy.action(experience_time_step)
147+
self._percentage_correct.update_state(experience.action,
148+
policy_actions.action)
149+
133150
# Check earlier rather than later if we should record summaries.
134151
# TF also checks it, but much later. Needed to avoid looping through
135152
# the dict so gave the if a bigger scope
@@ -147,6 +164,11 @@ def _update_metrics(self, experience, monitor_dict):
147164
name='num_trajectories',
148165
data=self._num_trajectories.result(),
149166
step=self._global_step)
167+
if self._bc_percentage_correct:
168+
tf.summary.scalar(
169+
name='percentage_correct',
170+
data=self._percentage_correct.result(),
171+
step=self._global_step)
150172

151173
for name_scope, d in monitor_dict.items():
152174
with tf.name_scope(name_scope + '/'):
@@ -159,6 +181,7 @@ def _update_metrics(self, experience, monitor_dict):
159181
def _reset_metrics(self):
160182
"""Reset num_trajectories."""
161183
self._num_trajectories.reset_states()
184+
self._percentage_correct.reset_state()
162185

163186
def _log_experiment(self, loss):
164187
"""Log training info."""
@@ -204,6 +227,8 @@ def train(self, dataset_iter, monitor_dict, num_iterations: int):
204227

205228
loss = self._agent.train(experience)
206229

230+
self._percentage_correct.reset_state()
231+
207232
self._update_metrics(experience, monitor_dict)
208233
self._log_experiment(loss.loss)
209234
self._save_checkpoint()

compiler_opt/rl/trainer_test.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import tensorflow as tf
2020
from tf_agents.agents.behavioral_cloning import behavioral_cloning_agent
21-
from tf_agents.networks import q_rnn_network
21+
from tf_agents.networks import q_network
2222
from tf_agents.specs import tensor_spec
2323
from tf_agents.trajectories import time_step
2424
from tf_agents.trajectories import trajectory
@@ -66,10 +66,9 @@ def setUp(self):
6666
minimum=0,
6767
maximum=1,
6868
name='inlining_decision')
69-
self._network = q_rnn_network.QRnnNetwork(
69+
self._network = q_network.QNetwork(
7070
input_tensor_spec=self._time_step_spec.observation,
7171
action_spec=self._action_spec,
72-
lstm_size=(40,),
7372
preprocessing_layers={
7473
'callee_users': tf.keras.layers.Lambda(lambda x: x)
7574
})
@@ -154,6 +153,25 @@ def test_training_metrics(self):
154153
self.assertEqual(2, test_trainer._data_reward_mean.result().numpy())
155154
self.assertEqual(90, test_trainer._num_trajectories.result().numpy())
156155

156+
def test_training_metrics_bc(self):
157+
test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
158+
self._time_step_spec,
159+
self._action_spec,
160+
self._network,
161+
tf.compat.v1.train.AdamOptimizer(),
162+
num_outer_dims=2)
163+
test_trainer = trainer.Trainer(
164+
root_dir=self.get_temp_dir(),
165+
agent=test_agent,
166+
summary_log_interval=1,
167+
bc_percentage_correct=True)
168+
169+
dataset_iter = _create_test_data(batch_size=3, sequence_length=3)
170+
monitor_dict = {'default': {'test': 1}}
171+
test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)
172+
173+
self.assertLess(0.1, test_trainer._percentage_correct.result().numpy())
174+
157175
def test_inference(self):
158176
test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
159177
self._time_step_spec,

0 commit comments

Comments
 (0)