Skip to content

Commit 9915a6d

Browse files
authored
[BC]Weighted bc training (#416)
Training code for the BC-Max algorithm, which includes the TensorFlow code to train a new policy and save it as a tf-policy, and code to compute the re-weighting for the supervised learning problem. This required updates to SequenceExampleFeatureNames from generate_bc_trajectories_lib.
1 parent e7b7e1c commit 9915a6d

File tree

5 files changed

+771
-8
lines changed

5 files changed

+771
-8
lines changed

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ class SequenceExampleFeatureNames:
6262
"""Feature names for features that are always added to seq example."""
6363
action: str = 'action'
6464
reward: str = 'reward'
65+
loss: str = 'loss'
66+
regret: str = 'regret'
6567
module_name: str = 'module_name'
68+
horizon: str = 'horizon'
69+
label_name: str = 'label'
6670

6771

6872
def get_loss(seq_example: tf.train.SequenceExample,
@@ -631,7 +635,8 @@ def _partition_for_loss(self, seq_example: tf.train.SequenceExample,
631635
seq_loss = get_loss(seq_example)
632636

633637
label = bisect.bisect_right(partitions, seq_loss)
634-
horizon = len(seq_example.feature_lists.feature_list['action'].feature)
638+
horizon = len(seq_example.feature_lists.feature_list[
639+
SequenceExampleFeatureNames.action].feature)
635640
label_list = [label for _ in range(horizon)]
636641
add_feature_list(seq_example, label_list, label_name)
637642

@@ -640,7 +645,7 @@ def process_succeeded(
640645
succeeded: List[Tuple[List, List[str], int, float]],
641646
spec_name: str,
642647
partitions: List[float],
643-
label_name: str = 'label'
648+
label_name: str = SequenceExampleFeatureNames.label_name
644649
) -> Tuple[tf.train.SequenceExample, ProfilingDictValueType,
645650
ProfilingDictValueType]:
646651
seq_example_list = [exploration_res[0] for exploration_res in succeeded]
@@ -691,12 +696,13 @@ def _profiling_dict(
691696
"""
692697

693698
per_module_dict = {
694-
'module_name':
699+
SequenceExampleFeatureNames.module_name:
695700
module_name,
696-
'loss':
701+
SequenceExampleFeatureNames.loss:
697702
float(get_loss(feature_list)),
698-
'horizon':
699-
len(feature_list.feature_lists.feature_list['action'].feature),
703+
SequenceExampleFeatureNames.horizon:
704+
len(feature_list.feature_lists.feature_list[
705+
SequenceExampleFeatureNames.action].feature),
700706
}
701707
return per_module_dict
702708

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Module for training an inlining policy with imitation learning."""
16+
17+
from absl import app
18+
from absl import flags
19+
from absl import logging
20+
21+
import gin
22+
import json
23+
from compiler_opt.rl import policy_saver
24+
25+
from compiler_opt.rl.inlining import imitation_learning_config as config
26+
27+
from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import TrainingWeights
28+
from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import ImitationLearningTrainer
29+
from compiler_opt.rl.imitation_learning.weighted_bc_trainer_lib import WrapKerasModel
30+
31+
# Command-line flags for the BC-Max training binary.

# Training data paths; forwarded unchanged to the trainer's `train` call.
_TRAINING_DATA = flags.DEFINE_multi_string(
    'training_data', None, 'Training data for one step of BC-Max')
# Profile json files, consumed pairwise by train(): each comparator profile is
# immediately followed by the matching eval profile.
# NOTE: adjacent string literals are concatenated verbatim, so every fragment
# except the last must end with a space to keep the help text readable.
_PROFILING_DATA = flags.DEFINE_multi_string(
    'profiling_data', None,
    ('Paths to profile files for computing the TrainingWeights. '
     'If specified the order for each pair of json files is '
     'comparator.json followed by eval.json and the number of '
     'files should always be even.'))
# Output directory; train() only exports a policy when this flag is set.
_SAVE_MODEL_DIR = flags.DEFINE_string(
    'save_model_dir', None, 'Location to save the keras and TFAgents policies.')
# Gin configuration files and overrides, parsed in main() before training.
_GIN_FILES = flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.')
_GIN_BINDINGS = flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files.')
46+
47+
48+
def train():
  """Runs one BC-Max training step and optionally exports the policy.

  When --profiling_data is given, the paths are read as consecutive
  (comparator, eval) json pairs and each pair is fed to a `TrainingWeights`
  instance. An `ImitationLearningTrainer` is then trained on --training_data.
  If --save_model_dir is set, the trained keras policy is wrapped in a
  `WrapKerasModel` and written out with `policy_saver.PolicySaver`.

  Raises:
    ValueError: if an odd number of profiling paths was supplied.
  """
  profile_paths = _PROFILING_DATA.value
  output_dir = _SAVE_MODEL_DIR.value

  training_weights = None
  if profile_paths:
    if len(profile_paths) % 2 != 0:
      raise ValueError('Profiling file paths should always be an even number.')
    training_weights = TrainingWeights()
    # Even positions hold comparator profiles, odd positions the eval profile
    # that belongs to the comparator right before it.
    path_pairs = zip(profile_paths[0::2], profile_paths[1::2])
    for comparator_path, eval_path in path_pairs:
      with open(comparator_path, encoding='utf-8') as comp_f, open(
          eval_path, encoding='utf-8') as eval_f:
        comparator_prof = json.load(comp_f)
        eval_prof = json.load(eval_f)
        training_weights.update_weights(
            comparator_profile=comparator_prof, policy_profile=eval_prof)

  trainer = ImitationLearningTrainer(
      save_model_dir=output_dir, training_weights=training_weights)
  trainer.train(filepaths=_TRAINING_DATA.value)

  # Nothing to export unless an output directory was requested.
  if not output_dir:
    return

  expected_signature, action_spec = config.get_input_signature()
  wrapped_keras_model = WrapKerasModel(
      keras_policy=trainer.get_policy(),
      time_step_spec=expected_signature,
      action_spec=action_spec)
  saver = policy_saver.PolicySaver(
      policy_dict={'tf_agents_policy': wrapped_keras_model})
  saver.save(output_dir)
75+
76+
77+
def main(_):
  """Loads the gin configuration, logs it, and starts training.

  Args:
    _: unused positional argument supplied by `app.run`.
  """
  config_files = _GIN_FILES.value
  bindings = _GIN_BINDINGS.value
  gin.parse_config_files_and_bindings(
      config_files, bindings, skip_unknown=False)
  # Log the configuration string reported by gin before training starts.
  logging.info(gin.config_str())

  train()
83+
84+
85+
# Run main() through absl's app runner when this file is executed as a script.
if __name__ == '__main__':
86+
app.run(main)

0 commit comments

Comments
 (0)