Commit 935f5fa

[BC] Trainer fixes (#444)
Changes to weighted_bc_trainer_lib and weighted_bc_trainer to ensure that only ```_train_step``` executes in Graph mode, and to fix a possible division by zero in weights creation.
1 parent 0dfe312 commit 935f5fa

File tree

3 files changed (+27, -19 lines)

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -825,7 +825,7 @@ def select_best_exploration(
       loaded_module_spec: corpus.LoadedModuleSpec,
   ) -> tuple[tuple[int, ProfilingDictValueType, ProfilingDictValueType],
              tf.train.SequenceExample]:
-
+    logging.set_verbosity('info')
     num_calls = len(self._tf_policy_action)
     time_call_compiler = 0
     logging.info('Processing module: %s', loaded_module_spec.name)
```

compiler_opt/rl/imitation_learning/weighted_bc_trainer.py

Lines changed: 11 additions & 9 deletions
```diff
@@ -13,20 +13,20 @@
 # limitations under the License.
 """Module for training an inlining policy with imitation learning."""
 
-from absl import app
-from absl import flags
-from absl import logging
+import json
 
 import gin
-import json
-from compiler_opt.rl import policy_saver
+import tensorflow as tf
+from absl import app, flags, logging
 
+from compiler_opt.rl import policy_saver
+from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import (
+    ImitationLearningTrainer,
+    TrainingWeights,
+    WrapKerasModel,
+)
 from compiler_opt.rl.inlining import imitation_learning_config as config
 
-from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import TrainingWeights
-from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import ImitationLearningTrainer
-from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import WrapKerasModel
-
 _TRAINING_DATA = flags.DEFINE_multi_string(
     'training_data', None, 'Training data for one step of BC-Max')
 _PROFILING_DATA = flags.DEFINE_multi_string(
```
```diff
@@ -78,6 +78,8 @@ def main(_):
       _GIN_FILES.value, _GIN_BINDINGS.value, skip_unknown=False)
   logging.info(gin.config_str())
 
+  tf.compat.v1.enable_eager_execution()  # pytype: disable=module-attr
+
   train()
 
 
```
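Together with the ```@tf.function``` change in weighted_bc_trainer_lib.py below, this restores the usual TF2 split: everything runs eagerly except the train step, which is traced into a graph. A minimal standalone sketch of that pattern (the `toy_step` function and its data are invented for illustration, not repo code):

```python
import tensorflow as tf

# Eager execution is the process-wide default; only functions explicitly
# decorated with @tf.function get traced into a graph.
tf.compat.v1.enable_eager_execution()

opt = tf.keras.optimizers.SGD(learning_rate=0.1)
w = tf.Variable(1.0)


@tf.function  # this step alone runs in Graph mode
def toy_step(x, y):
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(w * x - y))
  grads = tape.gradient(loss, [w])
  opt.apply_gradients(zip(grads, [w]))
  return loss


# The surrounding loop stays eager: plain Python control flow and logging.
for _ in range(3):
  loss = toy_step(tf.constant([1.0, 2.0]), tf.constant([2.0, 4.0]))
  print(float(loss))
```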
compiler_opt/rl/imitation_learning/weighted_bc_trainer_lib.py

Lines changed: 15 additions & 9 deletions
```diff
@@ -177,8 +177,8 @@ def update_weights(
       bucket_loss += np.maximum(prof[SequenceExampleFeatureNames.regret], 0)
     losses_per_bucket.append(bucket_loss)
     logging.info('Losses per bucket: %s', losses_per_bucket)
-    losses_per_bucket_normalized = losses_per_bucket / np.max(
-        np.abs(losses_per_bucket))
+    losses_per_bucket_normalized = losses_per_bucket / (
+        np.max(np.abs(losses_per_bucket)) + 1e-6)
     probs_t = self._get_exp_gradient_step(losses_per_bucket_normalized, 1.0)
     self._round += 1
     self._probs = (self._probs * (self._round - 1) + probs_t) / self._round
```
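If every bucket regret is zero (no exploration step beat the baseline), `np.max(np.abs(losses_per_bucket))` is 0 and the old normalization turned the whole array into NaNs. A standalone numpy illustration of the guard, with made-up values:

```python
import numpy as np

losses = np.array([0.0, 0.0, 0.0])  # all regrets clipped to zero

# Old form: 0 / 0 yields nan for every bucket (plus a RuntimeWarning).
old = losses / np.max(np.abs(losses))

# Fixed form: the epsilon keeps the denominator strictly positive.
new = losses / (np.max(np.abs(losses)) + 1e-6)

print(old)  # [nan nan nan]
print(new)  # [0. 0. 0.]
```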
```diff
@@ -228,6 +228,7 @@ def __init__(
     self._trainig_weights = TrainingWeights()
     self._features_to_remove = features_to_remove
     self._global_step = 0
+    self._is_model_init = False
 
     observation_spec, action_spec = config.get_inlining_signature_spec()
     sequence_features = {
```
```diff
@@ -322,13 +323,12 @@ def load_dataset(self, filepaths: list[str]) -> tf.data.TFRecordDataset:
         self._make_feature_label, num_processors=self._num_processors))
     dataset = dataset.unbatch().shuffle(self._shuffle_size).batch(
         self._batch_size, drop_remainder=True)  # 4194304
-    dataset = dataset.apply(tf.data.experimental.ignore_errors())
 
     return dataset
 
   def _create_weights(self, labels, weights_arr):
-    p_norm = min(weights_arr)  # check that this should be min
-    weights_arr = tf.map_fn(lambda x: p_norm / x, tf.constant(weights_arr))
+    p_norm = tf.reduce_min(weights_arr)
+    weights_arr = tf.math.divide(p_norm, weights_arr)
     int_labels = tf.cast(labels, tf.int32)
     return tf.gather(weights_arr, int_labels)
 
```
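`tf.reduce_min` and `tf.math.divide` work directly on tensors, whereas the old Python `min()` plus `tf.constant(weights_arr)` assumed a plain Python list and would break once `_train_step` (and anything it calls) is traced by `@tf.function`. An illustrative run with made-up class weights:

```python
import tensorflow as tf

weights_arr = tf.constant([2.0, 4.0, 8.0])  # per-class weights (assumed)
labels = tf.constant([0.0, 2.0, 1.0, 0.0])  # float labels, as in the diff

p_norm = tf.reduce_min(weights_arr)              # 2.0
per_class = tf.math.divide(p_norm, weights_arr)  # [1.0, 0.5, 0.25]
int_labels = tf.cast(labels, tf.int32)
per_example = tf.gather(per_class, int_labels)   # [1.0, 0.25, 0.5, 1.0]

print(per_example.numpy())
```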
```diff
@@ -365,6 +365,7 @@ def _update_metrics(self, y_true, y_pred, loss, weights):
     tf.summary.scalar(
         name=metric.name, data=metric.result(), step=self._global_step)
 
+  @tf.function
   def _train_step(self, example, label, weight_labels, weights_arr):
     y_true = label[:, 0]
     y_true = tf.reshape(y_true, [self._batch_size, 1])
```
```diff
@@ -381,10 +382,15 @@ def train(self, filepaths: list[str]):
     """Train the model for number of the specified number of epochs."""
     dataset = self.load_dataset(filepaths)
     logging.info('Datasets loaded from %s', str(filepaths))
-    input_shape = dataset.element_spec[0].shape[-1]
-    self._initialize_model(input_shape=input_shape)
-    self._initialize_metrics()
-    for _ in range(self._epochs):
+    input_shape = int(dataset.element_spec[0].shape[-1])
+    if not self._is_model_init:
+      self._initialize_model(input_shape=input_shape)
+      self._initialize_metrics()
+      self._is_model_init = True
+    self._global_step = 0
+    logging.info('Training started')
+    for epoch in range(self._epochs):
+      logging.info('Epoch %s', epoch)
       for metric in self._metrics:
         metric.reset_states()
       for step, (x_batch_train, y_batch_train) in enumerate(dataset):
```
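The `_is_model_init` guard makes `train()` idempotent with respect to model construction, so a driver can invoke it once per BC-Max round without resetting the weights learned so far. A minimal sketch of the guard pattern (hypothetical `Trainer`, not the repo class):

```python
class Trainer:

  def __init__(self):
    self._is_model_init = False
    self._model = None

  def _initialize_model(self, input_shape):
    # Stand-in for building the Keras model for the given input width.
    self._model = {'input_shape': input_shape, 'steps': 0}

  def train(self, input_shape):
    if not self._is_model_init:  # build once, keep weights afterwards
      self._initialize_model(input_shape)
      self._is_model_init = True
    self._model['steps'] += 1
    return self._model['steps']


t = Trainer()
assert t.train(input_shape=34) == 1
assert t.train(input_shape=34) == 2  # second round continues, no rebuild
```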
