@@ -114,9 +114,6 @@ def __init__(
114
114
self ._checkpointer .initialize_or_restore ()
115
115
116
116
self ._start_time = time .time ()
117
- self ._last_checkpoint_step = 0
118
- self ._last_log_step = 0
119
- self ._summary_last_log_step = 0
120
117
121
118
def _initialize_metrics (self ):
122
119
"""Initializes metrics."""
@@ -126,8 +123,7 @@ def _initialize_metrics(self):
126
123
127
124
def _update_metrics (self , experience , monitor_dict ):
128
125
"""Updates metrics and exports to Tensorboard."""
129
- if (self ._global_step .numpy () >=
130
- self ._summary_last_log_step + self ._summary_log_interval ):
126
+ if tf .math .equal (self ._global_step % self ._summary_log_interval , 0 ):
131
127
is_action = ~ experience .is_boundary ()
132
128
133
129
self ._data_action_mean .update_state (
@@ -136,6 +132,10 @@ def _update_metrics(self, experience, monitor_dict):
136
132
experience .reward , sample_weight = is_action )
137
133
self ._num_trajectories .update_state (experience .is_first ())
138
134
135
+ # Check earlier rather than later if we should record summaries.
136
+ # TF also checks it, but much later. Needed to avoid looping through
137
+ # the dict so gave the if a bigger scope
138
+ if tf .summary .should_record_summaries ():
139
139
with tf .name_scope ('default/' ):
140
140
tf .summary .scalar (
141
141
name = 'data_action_mean' ,
@@ -158,28 +158,23 @@ def _update_metrics(self, experience, monitor_dict):
158
158
tf .summary .histogram (
159
159
name = 'reward' , data = experience .reward , step = self ._global_step )
160
160
161
- self ._summary_last_log_step = self ._global_step .numpy ()
162
-
163
161
def _reset_metrics (self ):
164
162
"""Reset num_trajectories."""
165
163
self ._num_trajectories .reset_states ()
166
164
167
165
def _log_experiment (self , loss ):
168
166
"""Log training info."""
169
- global_step_val = self ._global_step . numpy ()
170
- if global_step_val - self . _last_log_step > = self ._log_interval :
167
+ if tf . math . equal ( self ._global_step % self . _log_interval , 0 ):
168
+ global_step_val = self ._global_step . numpy ()
171
169
logging .info ('step = %d, loss = %g' , global_step_val , loss )
172
170
time_acc = time .time () - self ._start_time
173
- steps_per_sec = ( global_step_val - self ._last_log_step ) / time_acc
171
+ steps_per_sec = self ._log_interval / time_acc
174
172
logging .info ('%.3f steps/sec' , steps_per_sec )
175
- self ._last_log_step = global_step_val
176
173
self ._start_time = time .time ()
177
174
178
175
def _save_checkpoint (self ):
179
- if (self ._global_step .numpy () - self ._last_checkpoint_step >=
180
- self ._checkpoint_interval ):
176
+ if tf .math .equal (self ._global_step % self ._checkpoint_interval , 0 ):
181
177
self ._checkpointer .save (global_step = self ._global_step )
182
- self ._last_checkpoint_step = self ._global_step .numpy ()
183
178
184
179
def global_step_numpy(self):
  """Return the current global training step as a NumPy scalar."""
  return self._global_step.numpy()
0 commit comments