Skip to content

Commit 7e94d9e

Browse files
authored
[BC] Inlining for size configs. (#410)
* Adding configs for collecting imitation learning tarjectories for inlining for size. * Removing values from imitation_learning.gin * Removed unused flags for gin bindings and gin configs.
1 parent 513d50d commit 7e94d9e

File tree

7 files changed

+177
-32
lines changed

7 files changed

+177
-32
lines changed

compiler_opt/rl/imitation_learning/generate_bc_trajectories.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,19 @@
2121
import gin
2222

2323
from compiler_opt.rl.imitation_learning import generate_bc_trajectories_lib
24-
from compiler_opt.tools import generate_test_model # pylint:disable=unused-import
2524

2625
from tf_agents.system import system_multiprocessing as multiprocessing
2726

28-
flags.FLAGS['gin_files'].allow_override = True
29-
flags.FLAGS['gin_bindings'].allow_override = True
30-
31-
FLAGS = flags.FLAGS
27+
_GIN_FILES = flags.DEFINE_multi_string(
28+
'gin_files', [], 'List of paths to gin configuration files.')
29+
_GIN_BINDINGS = flags.DEFINE_multi_string(
30+
'gin_bindings', [],
31+
'Gin bindings to override the values set in the config files.')
3232

3333

3434
def main(_):
3535
gin.parse_config_files_and_bindings(
36-
FLAGS.gin_files, bindings=FLAGS.gin_bindings, skip_unknown=True)
36+
_GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=True)
3737
logging.info(gin.config_str())
3838

3939
generate_bc_trajectories_lib.gen_trajectories()

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator, Union
2121
import json
2222

23-
from absl import flags
23+
# from absl import flags
2424
from absl import logging
2525
import bisect
2626
import dataclasses
@@ -46,13 +46,6 @@
4646
from compiler_opt.distributed import buffered_scheduler
4747
from compiler_opt.distributed.local import local_worker_manager
4848

49-
from compiler_opt.tools import generate_test_model # pylint:disable=unused-import
50-
51-
flags.FLAGS['gin_files'].allow_override = True
52-
flags.FLAGS['gin_bindings'].allow_override = True
53-
54-
FLAGS = flags.FLAGS
55-
5649
ProfilingDictValueType = Dict[str, Union[str, float, int]]
5750

5851

@@ -350,6 +343,7 @@ def __init__(
350343
task_type=mlgo_task_type,
351344
obs_spec=obs_spec,
352345
action_spec=action_spec,
346+
interactive_only=True,
353347
)
354348
if self._env.action_spec:
355349
if self._env.action_spec.dtype != tf.int64:
@@ -703,7 +697,7 @@ def _save_binary(self, base_path: str, save_path: str, binary_path: str):
703697
if not os.path.exists(save_dir):
704698
os.makedirs(save_dir, exist_ok=True)
705699
shutil.copy(
706-
os.path.join(binary_path, 'comp_binary'),
700+
os.path.join(binary_path, 'compiled_module'),
707701
os.path.join(save_dir, path_tail))
708702

709703

@@ -782,6 +776,7 @@ def __init__(
782776
self._obs_action_specs: Optional[Tuple[
783777
time_step.TimeStep, tensor_spec.BoundedTensorSpec]] = obs_action_specs
784778
self._mw_utility = ModuleWorkerResultProcessor(base_path)
779+
self._base_path = base_path
785780
self._partitions = partitions
786781
self._envargs = envargs
787782

@@ -860,7 +855,14 @@ def select_best_exploration(
860855
for temp_dirs in working_dir_list:
861856
for temp_dir in temp_dirs:
862857
temp_dir_head = os.path.split(temp_dir)[0]
863-
shutil.rmtree(temp_dir_head)
858+
try:
859+
shutil.rmtree(temp_dir_head)
860+
except FileNotFoundError as e:
861+
if not self._base_path:
862+
continue
863+
else:
864+
raise FileNotFoundError(
865+
f'Compilation directory {temp_dir_head} does not exist.') from e
864866

865867
return (
866868
num_calls,
@@ -878,6 +880,8 @@ def gen_trajectories(
878880
output_path: str = gin.REQUIRED,
879881
mlgo_task_type: Type[env.MLGOTask] = gin.REQUIRED,
880882
callable_policies: List[Optional[Callable[[Any], np.ndarray]]] = [],
883+
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
884+
bool]]] = None,
881885
obs_action_spec: Optional[Tuple[time_step.TimeStep,
882886
tensor_spec.BoundedTensorSpec]] = None,
883887
num_workers: Optional[int] = None,
@@ -898,6 +902,8 @@ def gen_trajectories(
898902
callable_policies: list of policies in the form of callable functions,
899903
this supplements the loaded policies from policy_paths given
900904
in ModuleWorker
905+
explore_on_features: dict of feature names and functions which specify
906+
when to explore on the respective feature
901907
obs_action_spec: optional observation and action spec annotating the state
902908
(TimeStep) for training a policy
903909
num_workers: number of distributed workers to process the corpus.
@@ -937,6 +943,7 @@ def gen_trajectories(
937943
obs_action_specs=obs_action_spec,
938944
mlgo_task_type=mlgo_task_type,
939945
callable_policies=callable_policies,
946+
explore_on_features=explore_on_features,
940947
gin_config_str=gin.config_str(),
941948
) as lwm:
942949

compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import functools
1818
from absl import app
19-
from absl import flags
2019
import gin
2120
import json
2221
from typing import List
@@ -36,8 +35,8 @@
3635
from compiler_opt.rl import env
3736
from compiler_opt.rl import env_test
3837

39-
flags.FLAGS['gin_files'].allow_override = True
40-
flags.FLAGS['gin_bindings'].allow_override = True
38+
# flags.FLAGS['gin_files'].allow_override = True
39+
# flags.FLAGS['gin_bindings'].allow_override = True
4140

4241
_eps = 1e-5
4342

@@ -648,17 +647,17 @@ def select_best_exploration(self, mock_popen, loaded_module_spec):
648647
class GenTrajectoriesTest(tf.test.TestCase):
649648

650649
def setUp(self):
651-
with gin.unlock_config():
652-
gin.parse_config_files_and_bindings(
653-
config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
654-
bindings=[
655-
('generate_bc_trajectories_test.'
656-
'MockModuleWorker.clang_path="/test/clang/path"'),
657-
('generate_bc_trajectories_test.'
658-
'MockModuleWorker.exploration_frac=1.0'),
659-
('generate_bc_trajectories_test.'
660-
'MockModuleWorker.reward_key="default"'),
661-
])
650+
with gin.config_scope('gen_trajectories_test'):
651+
with gin.unlock_config():
652+
gin.bind_parameter(
653+
'generate_bc_trajectories_test.MockModuleWorker.clang_path',
654+
'/test/clang/path')
655+
gin.bind_parameter(
656+
'generate_bc_trajectories_test.MockModuleWorker.exploration_frac',
657+
1.0)
658+
gin.bind_parameter(
659+
'generate_bc_trajectories_test.MockModuleWorker.reward_key',
660+
'default')
662661
return super().setUp()
663662

664663
def test_gen_trajectories(self):

compiler_opt/rl/inlining/env.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def get_cmdline(self, clang_path: str, base_args: List[str],
4444
'-enable-ml-inliner=release',
4545
'-mllvm',
4646
f'-inliner-interactive-channel-base={interactive_base_path}',
47-
#'-mllvm',
48-
#'-inliner-interactive-include-default',
47+
'-mllvm',
48+
'-inliner-interactive-include-default',
4949
]
5050
else:
5151
interactive_args = []
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import compiler_opt.rl.inlining.env
2+
import compiler_opt.rl.inlining.imitation_learning_config
3+
import compiler_opt.rl.imitation_learning.generate_bc_trajectories_lib
4+
5+
6+
env.InliningForSizeTask.llvm_size_path=''
7+
8+
generate_bc_trajectories_lib.ModuleWorker.clang_path=''
9+
generate_bc_trajectories_lib.ModuleWorker.mlgo_task_type=@env.InliningForSizeTask
10+
generate_bc_trajectories_lib.ModuleWorker.policy_paths=['']
11+
generate_bc_trajectories_lib.ModuleWorker.exploration_policy_paths=[]
12+
generate_bc_trajectories_lib.ModuleWorker.explore_on_features=None
13+
generate_bc_trajectories_lib.ModuleWorker.base_path=''
14+
generate_bc_trajectories_lib.ModuleWorker.partitions=[
15+
285.0, 376.0, 452.0, 512.0, 571.0, 627.5, 720.0, 809.5, 1304.0, 1832.0,
16+
2467.0, 3344.0, 4545.0, 6459.0, 9845.0, 17953.0, 29430.5, 85533.5,
17+
124361.0]
18+
generate_bc_trajectories_lib.ModuleWorker.reward_key='default'
19+
# generate_bc_trajectories_lib.ModuleWorker.gin_config_str=None
20+
21+
generate_bc_trajectories_lib.gen_trajectories.data_path=''
22+
generate_bc_trajectories_lib.gen_trajectories.delete_flags=('-split-dwarf-file', '-split-dwarf-output')
23+
generate_bc_trajectories_lib.gen_trajectories.output_file_name=''
24+
generate_bc_trajectories_lib.gen_trajectories.output_path=''
25+
generate_bc_trajectories_lib.gen_trajectories.mlgo_task_type=@imitation_learning_config.get_task_type()
26+
generate_bc_trajectories_lib.gen_trajectories.obs_action_spec=@imitation_learning_config.get_inlining_signature_spec()
27+
generate_bc_trajectories_lib.gen_trajectories.num_workers=1
28+
generate_bc_trajectories_lib.gen_trajectories.num_output_files=1
29+
generate_bc_trajectories_lib.gen_trajectories.profiling_file_path=''
30+
generate_bc_trajectories_lib.gen_trajectories.worker_wait_sec=100
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Module for collect data of inlining-for-size."""
16+
17+
import gin
18+
from typing import Type
19+
20+
import numpy as np
21+
import tensorflow as tf
22+
from tf_agents.trajectories import time_step
23+
24+
from compiler_opt.rl.inlining import config
25+
from compiler_opt.rl.inlining import env
26+
27+
28+
@gin.register
29+
def get_inlining_signature_spec():
30+
"""Returns (time_step_spec, action_spec) for collecting IL trajectories."""
31+
time_step_spec, action_spec = config.get_inlining_signature_spec()
32+
observation_spec = time_step_spec.observation
33+
observation_spec.update(
34+
dict((key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key)) for key in (
35+
'is_callee_avail_external',
36+
'is_caller_avail_external',
37+
# inlining_default is not used as feature in training.
38+
'inlining_default')))
39+
40+
time_step_spec = time_step.time_step_spec(observation_spec,
41+
time_step_spec.reward)
42+
43+
return time_step_spec, action_spec
44+
45+
46+
@gin.register
47+
def get_task_type() -> Type[env.InliningForSizeTask]:
48+
"""Returns the task type for the trajectory collection."""
49+
return env.InliningForSizeTask
50+
51+
52+
@gin.register
53+
def greedy_policy(state: time_step.TimeStep):
54+
"""Greedy policy playing the inlining_default action."""
55+
return np.array(state.observation['inlining_default'])
56+
57+
58+
@gin.register
59+
def explore_on_avail_external(state_observation: tf.Tensor) -> bool:
60+
return state_observation.numpy()[0]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Module for running compilation and collect data for behavior cloning."""
16+
17+
import functools
18+
from absl import app
19+
from absl import flags
20+
from absl import logging
21+
import gin
22+
23+
from compiler_opt.rl.imitation_learning import generate_bc_trajectories_lib
24+
from compiler_opt.rl.inlining import imitation_learning_config
25+
26+
from tf_agents.system import system_multiprocessing as multiprocessing
27+
28+
_GIN_FILES = flags.DEFINE_multi_string(
29+
'gin_files', [], 'List of paths to gin configuration files.')
30+
_GIN_BINDINGS = flags.DEFINE_multi_string(
31+
'gin_bindings', [],
32+
'Gin bindings to override the values set in the config files.')
33+
34+
35+
def main(_):
36+
gin.parse_config_files_and_bindings(
37+
_GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=True)
38+
logging.info(gin.config_str())
39+
40+
generate_bc_trajectories_lib.gen_trajectories(
41+
callable_policies=[imitation_learning_config.greedy_policy],
42+
explore_on_features={
43+
'is_callee_avail_external':
44+
imitation_learning_config.explore_on_avail_external
45+
})
46+
47+
48+
if __name__ == '__main__':
49+
multiprocessing.handle_main(functools.partial(app.run, main))

0 commit comments

Comments
 (0)