@@ -23,6 +23,7 @@
 from compiler_opt.rl import random_net_distillation
 from tf_agents.agents import tf_agent
 from tf_agents.policies import policy_loader
+from tf_agents import trajectories

 from tf_agents.utils import common as common_utils
 from typing import Optional
@@ -54,7 +55,8 @@ def __init__(
       log_interval=100,
       summary_log_interval=100,
       summary_export_interval=1000,
-      summaries_flush_secs=10):
+      summaries_flush_secs=10,
+      bc_percentage_correct=False):
     """Initialize the Trainer object.

     Args:
@@ -70,6 +72,9 @@ def __init__(
       summary_export_interval: int, the training step interval for exporting
         to tensorboard.
       summaries_flush_secs: int, the seconds for flushing to tensorboard.
+      bc_percentage_correct: bool, whether or not to log the accuracy of the
+        current batch. This is intended for use during BC training where labels
+        for the "correct" decision are available.
     """
     self._root_dir = root_dir
     self._agent = agent
@@ -84,6 +89,7 @@ def __init__(
     self._summary_writer.set_as_default()

     self._global_step = tf.compat.v1.train.get_or_create_global_step()
+    self._bc_percentage_correct = bc_percentage_correct

     # Initialize agent and trajectory replay.
     # Wrap training and trajectory replay in a tf.function to make it much
@@ -118,6 +124,7 @@ def _initialize_metrics(self):
     self._data_action_mean = tf.keras.metrics.Mean()
     self._data_reward_mean = tf.keras.metrics.Mean()
     self._num_trajectories = tf.keras.metrics.Sum()
+    self._percentage_correct = tf.keras.metrics.Accuracy()

   def _update_metrics(self, experience, monitor_dict):
     """Updates metrics and exports to Tensorboard."""
@@ -130,6 +137,16 @@ def _update_metrics(self, experience, monitor_dict):
         experience.reward, sample_weight=is_action)
     self._num_trajectories.update_state(experience.is_first())

+    # Compute the accuracy if we are BC training.
+    if self._bc_percentage_correct:
+      experience_time_step = trajectories.TimeStep(experience.step_type,
+                                                   experience.reward,
+                                                   experience.discount,
+                                                   experience.observation)
+      policy_actions = self._agent.policy.action(experience_time_step)
+      self._percentage_correct.update_state(experience.action,
+                                            policy_actions.action)
+
     # Check earlier rather than later if we should record summaries.
     # TF also checks it, but much later. Needed to avoid looping through
     # the dict so gave the if a bigger scope
@@ -147,6 +164,11 @@ def _update_metrics(self, experience, monitor_dict):
           name='num_trajectories',
           data=self._num_trajectories.result(),
           step=self._global_step)
+      if self._bc_percentage_correct:
+        tf.summary.scalar(
+            name='percentage_correct',
+            data=self._percentage_correct.result(),
+            step=self._global_step)

     for name_scope, d in monitor_dict.items():
       with tf.name_scope(name_scope + '/'):
@@ -159,6 +181,7 @@ def _update_metrics(self, experience, monitor_dict):
   def _reset_metrics(self):
     """Reset num_trajectories."""
     self._num_trajectories.reset_states()
+    self._percentage_correct.reset_state()

   def _log_experiment(self, loss):
     """Log training info."""
@@ -204,6 +227,8 @@ def train(self, dataset_iter, monitor_dict, num_iterations: int):

       loss = self._agent.train(experience)

+      self._percentage_correct.reset_state()
+
       self._update_metrics(experience, monitor_dict)
       self._log_experiment(loss.loss)
       self._save_checkpoint()
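
For context on the new `_update_metrics` logic: the patch rebuilds a `trajectories.TimeStep` from the fields of the replayed experience and asks the agent's policy for its action on that batch. Below is a minimal, self-contained sketch of the same pattern; the toy observation/action specs and the `RandomTFPolicy` are stand-ins for the real compiler environment and agent, not part of this change.

```python
import tensorflow as tf
from tf_agents.policies import random_tf_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

# Hypothetical specs; the real ones come from the compiler problem config.
observation_spec = tf.TensorSpec([2], tf.float32)
action_spec = tensor_spec.BoundedTensorSpec([], tf.int64, minimum=0, maximum=1)
policy = random_tf_policy.RandomTFPolicy(
    ts.time_step_spec(observation_spec), action_spec)

# TimeStep takes (step_type, reward, discount, observation) -- the same
# fields the patch pulls off the experience trajectory.
batch_size = 4
experience_time_step = ts.TimeStep(
    step_type=tf.fill([batch_size], ts.StepType.MID),
    reward=tf.zeros([batch_size], tf.float32),
    discount=tf.ones([batch_size], tf.float32),
    observation=tf.zeros([batch_size, 2], tf.float32))

# policy.action returns a PolicyStep; its .action field holds the actions
# that get compared against the logged BC actions.
policy_actions = policy.action(experience_time_step)
print(policy_actions.action)
```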
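And a short sketch of what `tf.keras.metrics.Accuracy` contributes here: it tracks the running fraction of positions where prediction equals label, so feeding it the logged demonstration actions as labels and the policy's actions as predictions yields the "percentage correct" scalar. The action values below are made up for illustration.

```python
import tensorflow as tf

percentage_correct = tf.keras.metrics.Accuracy()

# Made-up logged (demonstration) actions and policy actions for one batch.
logged_actions = tf.constant([0, 1, 1, 0, 1])
policy_actions = tf.constant([0, 1, 0, 0, 1])

percentage_correct.update_state(logged_actions, policy_actions)
print(percentage_correct.result().numpy())  # 0.8: 4 of 5 actions match

# Resetting each iteration (as train() now does) makes the exported scalar
# reflect only the most recent batch rather than a running average.
percentage_correct.reset_state()
```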