|
28 | 28 |
|
29 | 29 |
|
30 | 30 | def _create_test_data(batch_size, sequence_length):
|
| 31 | + # Use a step type of zero, which signals the beginning of a sequence, so |
| 32 | + # that the num_trajectories metric can be tested. |
31 | 33 | test_trajectory = trajectory.Trajectory(
|
32 |
| - step_type=tf.fill([batch_size, sequence_length], 1), |
| 34 | + step_type=tf.fill([batch_size, sequence_length], 0), |
33 | 35 | observation={
|
34 | 36 | 'callee_users':
|
35 | 37 | tf.fill([batch_size, sequence_length],
|
@@ -131,6 +133,27 @@ def test_training_with_multiple_times(self):
|
131 | 133 | test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)
|
132 | 134 | self.assertEqual(20, test_trainer._global_step.numpy())
|
133 | 135 |
|
| 136 | + def test_training_metrics(self): |
| 137 | + test_agent = behavioral_cloning_agent.BehavioralCloningAgent( |
| 138 | + self._time_step_spec, |
| 139 | + self._action_spec, |
| 140 | + self._network, |
| 141 | + tf.compat.v1.train.AdadeltaOptimizer(), |
| 142 | + num_outer_dims=2) |
| 143 | + test_trainer = trainer.Trainer( |
| 144 | + root_dir=self.get_temp_dir(), agent=test_agent, summary_log_interval=1) |
| 145 | + self.assertEqual(0, test_trainer._data_action_mean.result().numpy()) |
| 146 | + self.assertEqual(0, test_trainer._data_reward_mean.result().numpy()) |
| 147 | + self.assertEqual(0, test_trainer._num_trajectories.result().numpy()) |
| 148 | + |
| 149 | + dataset_iter = _create_test_data(batch_size=3, sequence_length=3) |
| 150 | + monitor_dict = {'default': {'test': 1}} |
| 151 | + test_trainer.train(dataset_iter, monitor_dict, num_iterations=10) |
| 152 | + |
| 153 | + self.assertEqual(1, test_trainer._data_action_mean.result().numpy()) |
| 154 | + self.assertEqual(2, test_trainer._data_reward_mean.result().numpy()) |
| 155 | + self.assertEqual(90, test_trainer._num_trajectories.result().numpy()) |
| 156 | + |
134 | 157 | def test_inference(self):
|
135 | 158 | test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
|
136 | 159 | self._time_step_spec,
|
|
0 commit comments