@@ -50,7 +50,8 @@ def __init__(
       # Params for summaries and logging
       checkpoint_interval=10000,
       log_interval=100,
-      summary_interval=1000,
+      summary_log_interval=100,
+      summary_export_interval=1000,
       summaries_flush_secs=10):
     """Initialize the Trainer object.
 
@@ -62,16 +63,19 @@ def __init__(
       checkpoint_interval: int, the training step interval for saving
         checkpoint.
       log_interval: int, the training step interval for logging.
-      summary_interval: int, the training step interval for exporting to
-        tensorboard.
+      summary_log_interval: int, the training step interval for logging
+        metrics to tensorboard.
+      summary_export_interval: int, the training step interval for exporting
+        to tensorboard.
       summaries_flush_secs: int, the seconds for flushing to tensorboard.
     """
     self._root_dir = root_dir
     self._agent = agent
     self._random_network_distillation = random_network_distillation
     self._checkpoint_interval = checkpoint_interval
     self._log_interval = log_interval
-    self._summary_interval = summary_interval
+    self._summary_log_interval = summary_log_interval
+    self._summary_export_interval = summary_export_interval
 
     self._summary_writer = tf.summary.create_file_writer(
         self._root_dir, flush_millis=summaries_flush_secs * 1000)
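
Note on the split: the old single summary_interval was doing two jobs, so it becomes two knobs. summary_log_interval throttles how often _update_metrics recomputes and writes the data metrics (see the hunks below), while summary_export_interval feeds the record_if predicate that gates summary export inside train. A construction sketch with hypothetical values, eliding the Trainer arguments this diff does not show:

    trainer = Trainer(
        root_dir='/tmp/logs',          # assumed path
        agent=agent,                   # defined elsewhere
        checkpoint_interval=10000,     # checkpoint every 10k steps
        log_interval=100,              # console logging every 100 steps
        summary_log_interval=100,      # metric updates every 100 steps
        summary_export_interval=1000,  # tensorboard export every 1000 steps
        summaries_flush_secs=10)
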
@@ -108,6 +112,7 @@ def __init__(
     self._start_time = time.time()
     self._last_checkpoint_step = 0
     self._last_log_step = 0
+    self._summary_last_log_step = 0
 
   def _initialize_metrics(self):
     """Initializes metrics."""
@@ -117,35 +122,39 @@ def _initialize_metrics(self):
 
   def _update_metrics(self, experience, monitor_dict):
     """Updates metrics and exports to Tensorboard."""
-    is_action = ~experience.is_boundary()
-
-    self._data_action_mean.update_state(
-        experience.action, sample_weight=is_action)
-    self._data_reward_mean.update_state(
-        experience.reward, sample_weight=is_action)
-    self._num_trajectories.update_state(experience.is_first())
-
-    with tf.name_scope('default/'):
-      tf.summary.scalar(
-          name='data_action_mean',
-          data=self._data_action_mean.result(),
-          step=self._global_step)
-      tf.summary.scalar(
-          name='data_reward_mean',
-          data=self._data_reward_mean.result(),
-          step=self._global_step)
-      tf.summary.scalar(
-          name='num_trajectories',
-          data=self._num_trajectories.result(),
-          step=self._global_step)
-
-    for name_scope, d in monitor_dict.items():
-      with tf.name_scope(name_scope + '/'):
-        for key, value in d.items():
-          tf.summary.scalar(name=key, data=value, step=self._global_step)
-
-    tf.summary.histogram(
-        name='reward', data=experience.reward, step=self._global_step)
+    if (self._global_step.numpy() >=
+        self._summary_last_log_step + self._summary_log_interval):
+      is_action = ~experience.is_boundary()
+
+      self._data_action_mean.update_state(
+          experience.action, sample_weight=is_action)
+      self._data_reward_mean.update_state(
+          experience.reward, sample_weight=is_action)
+      self._num_trajectories.update_state(experience.is_first())
+
+      with tf.name_scope('default/'):
+        tf.summary.scalar(
+            name='data_action_mean',
+            data=self._data_action_mean.result(),
+            step=self._global_step)
+        tf.summary.scalar(
+            name='data_reward_mean',
+            data=self._data_reward_mean.result(),
+            step=self._global_step)
+        tf.summary.scalar(
+            name='num_trajectories',
+            data=self._num_trajectories.result(),
+            step=self._global_step)
+
+      for name_scope, d in monitor_dict.items():
+        with tf.name_scope(name_scope + '/'):
+          for key, value in d.items():
+            tf.summary.scalar(name=key, data=value, step=self._global_step)
+
+      tf.summary.histogram(
+          name='reward', data=experience.reward, step=self._global_step)
+
+      self._summary_last_log_step = self._global_step.numpy()
 
   def _reset_metrics(self):
     """Reset num_trajectories."""
@@ -176,8 +185,8 @@ def train(self, dataset_iter, monitor_dict, num_iterations):
     self._reset_metrics()
     # context management is implemented in decorator
     # pylint: disable=not-context-manager
-    with tf.summary.record_if(
-        lambda: tf.math.equal(self._global_step % self._summary_interval, 0)):
+    with tf.summary.record_if(lambda: tf.math.equal(
+        self._global_step % self._summary_export_interval, 0)):
       for _ in range(num_iterations):
         # When the data is not enough to fill in a batch, next(dataset_iter)
         # will throw StopIteration exception, logging a warning message instead
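
For context on the gate this hunk renames: tf.summary.record_if accepts a constant, a tensor, or a callable, and every summary op inside the context evaluates it before writing, so with the modulo predicate only steps divisible by summary_export_interval reach the event file. An isolated sketch of that behaviour (path and values are illustrative):

    import tensorflow as tf

    step = tf.Variable(0, dtype=tf.int64)
    writer = tf.summary.create_file_writer('/tmp/record_if_demo')
    with writer.as_default():
      with tf.summary.record_if(lambda: tf.math.equal(step % 1000, 0)):
        for _ in range(3000):
          # Only steps 0, 1000 and 2000 are actually written.
          tf.summary.scalar('loss', 0.5, step=step)
          step.assign_add(1)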