Skip to content

Commit fffde33

Browse files
Add a test for trainer metrics (#399)
This patch adds a test for the metrics in the trainer class to ensure that they are actually set. The test inspects the metric state directly rather than also verifying that the values are logged to TensorBoard, but this should be good enough, and is definitely better than what we had before (nothing).
1 parent e314df5 commit fffde33

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

compiler_opt/rl/trainer_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828

2929

3030
def _create_test_data(batch_size, sequence_length):
31+
# Use the value zero, which signals the beginning of a sequence, so that
32+
# we can test the num_trajectories metric.
3133
test_trajectory = trajectory.Trajectory(
32-
step_type=tf.fill([batch_size, sequence_length], 1),
34+
step_type=tf.fill([batch_size, sequence_length], 0),
3335
observation={
3436
'callee_users':
3537
tf.fill([batch_size, sequence_length],
@@ -131,6 +133,27 @@ def test_training_with_multiple_times(self):
131133
test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)
132134
self.assertEqual(20, test_trainer._global_step.numpy())
133135

136+
def test_training_metrics(self):
137+
test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
138+
self._time_step_spec,
139+
self._action_spec,
140+
self._network,
141+
tf.compat.v1.train.AdadeltaOptimizer(),
142+
num_outer_dims=2)
143+
test_trainer = trainer.Trainer(
144+
root_dir=self.get_temp_dir(), agent=test_agent, summary_log_interval=1)
145+
self.assertEqual(0, test_trainer._data_action_mean.result().numpy())
146+
self.assertEqual(0, test_trainer._data_reward_mean.result().numpy())
147+
self.assertEqual(0, test_trainer._num_trajectories.result().numpy())
148+
149+
dataset_iter = _create_test_data(batch_size=3, sequence_length=3)
150+
monitor_dict = {'default': {'test': 1}}
151+
test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)
152+
153+
self.assertEqual(1, test_trainer._data_action_mean.result().numpy())
154+
self.assertEqual(2, test_trainer._data_reward_mean.result().numpy())
155+
self.assertEqual(90, test_trainer._num_trajectories.result().numpy())
156+
134157
def test_inference(self):
135158
test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
136159
self._time_step_spec,

0 commit comments

Comments
 (0)