@@ -177,8 +177,8 @@ def update_weights(
         bucket_loss += np.maximum(prof[SequenceExampleFeatureNames.regret], 0)
       losses_per_bucket.append(bucket_loss)
     logging.info('Losses per bucket: %s', losses_per_bucket)
-    losses_per_bucket_normalized = losses_per_bucket / np.max(
-        np.abs(losses_per_bucket))
+    losses_per_bucket_normalized = losses_per_bucket / (
+        np.max(np.abs(losses_per_bucket)) + 1e-6)
     probs_t = self._get_exp_gradient_step(losses_per_bucket_normalized, 1.0)
     self._round += 1
     self._probs = (self._probs * (self._round - 1) + probs_t) / self._round
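The `1e-6` guard matters in the all-zero case: if no bucket accumulates positive regret, the old code divides by `np.max(np.abs(losses_per_bucket)) == 0` and fills the normalized losses with NaNs, which then propagate into the exponentiated-gradient step. A minimal standalone NumPy sketch of the two behaviors:

```python
import numpy as np

losses = np.array([0.0, 0.0, 0.0])  # no bucket saw positive regret

# Old normalization: 0 / 0 yields NaN (NumPy warns but does not raise).
print(losses / np.max(np.abs(losses)))           # [nan nan nan]

# New normalization: the epsilon keeps the denominator positive.
print(losses / (np.max(np.abs(losses)) + 1e-6))  # [0. 0. 0.]
```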
@@ -228,6 +228,7 @@ def __init__(
     self._trainig_weights = TrainingWeights()
     self._features_to_remove = features_to_remove
     self._global_step = 0
+    self._is_model_init = False

     observation_spec, action_spec = config.get_inlining_signature_spec()
     sequence_features = {
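The new flag exists so that `train()` (patched below) can build the model and metrics once and reuse them on later calls. A toy sketch of the guard idiom, with hypothetical names:

```python
class LazyInitExample:
  """Builds an expensive object on the first call, then reuses it."""

  def __init__(self):
    self._is_model_init = False
    self._model = None

  def train(self):
    if not self._is_model_init:      # runs only on the first call
      self._model = {'weights': []}  # stand-in for model construction
      self._is_model_init = True
    return self._model               # later calls reuse the same object
```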
@@ -322,13 +323,12 @@ def load_dataset(self, filepaths: list[str]) -> tf.data.TFRecordDataset:
             self._make_feature_label, num_processors=self._num_processors))
     dataset = dataset.unbatch().shuffle(self._shuffle_size).batch(
         self._batch_size, drop_remainder=True)  # 4194304
-    dataset = dataset.apply(tf.data.experimental.ignore_errors())

     return dataset

   def _create_weights(self, labels, weights_arr):
-    p_norm = min(weights_arr)  # check that this should be min
-    weights_arr = tf.map_fn(lambda x: p_norm / x, tf.constant(weights_arr))
+    p_norm = tf.reduce_min(weights_arr)
+    weights_arr = tf.math.divide(p_norm, weights_arr)
     int_labels = tf.cast(labels, tf.int32)
     return tf.gather(weights_arr, int_labels)
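Replacing `tf.map_fn` with `tf.reduce_min` plus `tf.math.divide` computes the same per-class weights in one vectorized op and accepts a tensor input directly. A worked example of the new `_create_weights` arithmetic (values are illustrative):

```python
import tensorflow as tf

weights_arr = tf.constant([2.0, 1.0, 4.0])  # one raw weight per class
labels = tf.constant([0.0, 2.0, 1.0])       # per-example class labels

p_norm = tf.reduce_min(weights_arr)            # 1.0
weights = tf.math.divide(p_norm, weights_arr)  # [0.5, 1.0, 0.25]
int_labels = tf.cast(labels, tf.int32)
print(tf.gather(weights, int_labels).numpy())  # [0.5, 0.25, 1.0]
```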
@@ -365,6 +365,7 @@ def _update_metrics(self, y_true, y_pred, loss, weights):
       tf.summary.scalar(
           name=metric.name, data=metric.result(), step=self._global_step)

+  @tf.function
   def _train_step(self, example, label, weight_labels, weights_arr):
     y_true = label[:, 0]
     y_true = tf.reshape(y_true, [self._batch_size, 1])
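`@tf.function` traces `_train_step` once per input signature and replays the compiled graph afterwards, cutting eager-mode Python overhead out of the inner training loop. One behavior worth knowing: Python side effects inside the function run only during tracing. A minimal sketch:

```python
import tensorflow as tf

@tf.function
def step(x):
  print('tracing')         # Python side effect: runs only while tracing
  return tf.reduce_sum(x)  # graph op: runs on every call

x = tf.constant([1.0, 2.0])
step(x)  # prints 'tracing', returns 3.0
step(x)  # no print; reuses the traced graph
```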
@@ -381,10 +382,15 @@ def train(self, filepaths: list[str]):
     """Train the model for the specified number of epochs."""
     dataset = self.load_dataset(filepaths)
     logging.info('Datasets loaded from %s', str(filepaths))
-    input_shape = dataset.element_spec[0].shape[-1]
-    self._initialize_model(input_shape=input_shape)
-    self._initialize_metrics()
-    for _ in range(self._epochs):
+    input_shape = int(dataset.element_spec[0].shape[-1])
+    if not self._is_model_init:
+      self._initialize_model(input_shape=input_shape)
+      self._initialize_metrics()
+      self._is_model_init = True
+    self._global_step = 0
+    logging.info('Training started')
+    for epoch in range(self._epochs):
+      logging.info('Epoch %s', epoch)
       for metric in self._metrics:
         metric.reset_states()
       for step, (x_batch_train, y_batch_train) in enumerate(dataset):
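Together, the init guard and the `self._global_step = 0` reset make `train()` safe to call more than once on the same trainer: the first call builds the model and metrics, later calls reuse them and restart the step counter. A hedged usage sketch (paths are hypothetical):

```python
# `trainer` is an instance of the patched trainer class.
trainer.train(filepaths=['/data/shard0.tfrecord'])  # first call: builds model
trainer.train(filepaths=['/data/shard1.tfrecord'])  # reuses model, trains again
```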