@@ -103,7 +103,11 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
103103 self ._model_in_training = tl .Serial (model , task .loss_layer )
104104 self ._eval_model = model if eval_model is None else eval_model
105105 self ._eval_task = eval_task
106+ self ._rjust_len = max ([0 ] + [len (name ) for name in eval_task .metric_names ])
107+
106108 self ._output_dir = os .path .expanduser (output_dir ) if output_dir else None
109+ if output_dir is not None :
110+ tf .io .gfile .makedirs (output_dir )
107111 default_fn = _at_step_1_and_periodically_at (task .n_steps_per_checkpoint )
108112 self ._checkpoint_at = checkpoint_at or default_fn
109113 self ._eval_at = eval_at or default_fn
@@ -120,9 +124,10 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
120124 _ , _ = task .optimizer .tree_init (self ._model_in_training .weights )
121125
122126 self ._gradients_and_state_fn = (
123- fastmath .jit (fastmath .grad (self ._model_in_training .pure_fn ,
124- argnums = 1 , # arg1 of pure_fn: weights
125- has_aux = True ))) # return (gradients, state)
127+ fastmath .jit (fastmath .value_and_grad (
128+ self ._model_in_training .pure_fn ,
129+ argnums = 1 , # arg1 of pure_fn: weights
130+ has_aux = True ))) # return (loss, state), gradients
126131
127132 if eval_task is not None :
128133 model_with_metrics = _model_with_metrics (self ._eval_model , eval_task )
@@ -142,13 +147,23 @@ def run(self, n_steps=1):
142147 weights = self ._model_in_training .weights
143148 state = self ._model_in_training .state
144149 slots = self ._task .optimizer .slots
150+ loss_acc , step_acc = 0.0 , 0
145151 for _ in range (n_steps ):
146152 self ._step += 1
147- weights , state , slots = self ._run_one_step (weights , state , slots )
153+ loss , weights , state , slots = self ._run_one_step (weights , state , slots )
154+ loss_acc += loss
155+ step_acc += 1
148156 if self ._eval_at (self ._step ):
149157 self ._model_in_training .weights = weights
150158 self ._model_in_training .state = state
151159 self ._eval_model .weights = self ._model .weights
160+ # TODO(lukaszkaiser): move this to a better place with other reporting
161+ loss_name = self ._task .loss_layer .name
162+ step_acc = max (1 , step_acc ) # only here to avoid potential divide-by-0
163+ self ._log_step ('%s %s | % .8f' % (
164+ 'train' .ljust (5 ), loss_name .rjust (self ._rjust_len ),
165+ loss_acc / float (step_acc )))
166+ loss_acc , step_acc = 0.0 , 0
152167 self .run_evals (weights , state )
153168 if self ._checkpoint_at (self ._step ):
154169 self .save_checkpoint (weights , state , slots )
@@ -199,11 +214,11 @@ def _run_one_step(self, weights, state, slots):
199214 opt_params = optimizer ._init_opt_params # pylint: disable=protected-access
200215 opt_params .update ({'learning_rate' : self ._task .learning_rate (step )})
201216
202- gradients , updated_state = (
217+ ( loss , updated_state ), gradients = (
203218 self ._gradients_and_state_fn (batch , weights , state , self .new_rng ()))
204219 updated_weights , updated_slots , _ = (
205220 optimizer .tree_update (step , gradients , weights , slots , opt_params ))
206- return updated_weights , updated_state , updated_slots
221+ return loss , updated_weights , updated_state , updated_slots
207222
208223 def run_evals (self , weights = None , state = None ):
209224 """Runs and records evals for this training session.
@@ -230,10 +245,9 @@ def run_evals(self, weights=None, state=None):
230245 self ._metrics_fn (batch , metrics_weights , metrics_state , rng ))
231246 sums += metric_values
232247 averages = sums / n_batches
233- rjust_len = max ([0 ] + [len (name ) for name in eval_task .metric_names ])
234248 for name , average_value in zip (eval_task .metric_names , averages ):
235249 self ._log_step ('%s %s | % .8f' % (
236- 'eval' .ljust (5 ), name .rjust (rjust_len ), average_value ))
250+ 'eval' .ljust (5 ), name .rjust (self . _rjust_len ), average_value ))
237251
238252 def _log_step (self , msg ):
239253 """Logs message, labeled with the current training step number."""
0 commit comments