Skip to content

Commit 22e6b7a

Browse files
Refactor add/delete flags options to problem_configuration (#84)
This patch moves the add/delete flags from a gin config option to the problem configuration python object and makes generate_default_trace grab the flags from there.
1 parent 55cfce9 commit 22e6b7a

File tree

6 files changed

+30
-35
lines changed

6 files changed

+30
-35
lines changed

compiler_opt/rl/inlining/gin_configs/ppo_nn_agent.gin

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,6 @@ train_eval.deploy_policy_name='saved_collect_policy'
2020
train_eval.use_random_network_distillation=False
2121
train_eval.moving_average_decay_rate=0.8
2222

23-
# TODO(b/233935329): The following clang flags need to be tied to a corpus
24-
# rather than to a training tool invocation.
25-
26-
# List of flags to add to clang compilation command. The flag names should
27-
# match the actual flags provided to clang. An example for AFDO reinjection:
28-
# train_eval.flags_to_add=('-fprofile-sample-use=/path/to/gwp.afdo','-fprofile-remapping-file=/path/to/prof_remap.txt')
29-
train_eval.additional_compilation_flags=()
30-
31-
# List of flags to remove from clang compilation command. The flag names
32-
# should match the actual flags provided to clang.'
33-
train_eval.delete_compilation_flags=('-split-dwarf-file','-split-dwarf-output','-fthinlto-index','-fprofile-sample-use','-fprofile-remapping-file')
34-
3523
# RandomNetworkDistillation configs, off if train_eval.use_random_network_distillation=False.
3624
RandomNetworkDistillation.encoding_network = @encoding_network.EncodingNetwork
3725
RandomNetworkDistillation.learning_rate = 1e-4

compiler_opt/rl/problem_configuration.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,19 @@ def get_nonnormalized_features(self) -> Iterable[str]:
9999
@abc.abstractmethod
100100
def get_runner_type(self) -> 'type[compilation_runner.CompilationRunner]':
101101
raise NotImplementedError
102+
103+
# TODO(b/233935329): The following clang flags need to be tied to a corpus
104+
# rather than to a training tool invocation.
105+
106+
# List of flags to add to clang compilation command. The flag names should
107+
# match the actual flags provided to clang. An example for AFDO reinjection:
108+
# return ['-fprofile-sample-use=/path/to/gwp.afdo',
109+
# '-fprofile-remapping-file=/path/to/prof_remap.txt']
110+
def flags_to_add(self) -> Tuple[str, ...]:
111+
return ()
112+
113+
# List of flags to remove from clang compilation command. The flag names
114+
# should match the actual flags provided to clang.'
115+
def flags_to_delete(self) -> Tuple[str, ...]:
116+
return ('-split-dwarf-file', '-split-dwarf-output', '-fthinlto-index',
117+
'-fprofile-sample-use', '-fprofile-remapping-file')

compiler_opt/rl/regalloc/gin_configs/ppo_nn_agent.gin

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,6 @@ train_eval.deploy_policy_name='saved_collect_policy'
2121
train_eval.moving_average_decay_rate=0.8
2222
train_eval.use_random_network_distillation=False
2323

24-
# TODO(b/233935329): The following clang flags need to be tied to a corpus
25-
# rather than to a training tool invocation.
26-
27-
# List of flags to add to clang compilation command. The flag names should
28-
# match the actual flags provided to clang. An example for AFDO reinjection:
29-
# train_eval.additional_flags=('-fprofile-sample-use=/path/to/gwp.afdo','-fprofile-remapping-file=/path/to/prof_remap.txt')
30-
train_eval.additional_compilation_flags=()
31-
32-
# List of flags to remove from clang compilation command. The flag names
33-
# should match the actual flags provided to clang.'
34-
train_eval.delete_compilation_flags=('-split-dwarf-file','-split-dwarf-output','-fthinlto-index','-fprofile-sample-use','-fprofile-remapping-file')
35-
3624
# RandomNetworkDistillation configs, off if train_eval.use_random_network_distillation=False.
3725
RandomNetworkDistillation.encoding_network = @regalloc_network.RegAllocRNDEncodingNetwork
3826
RandomNetworkDistillation.learning_rate = 1e-4

compiler_opt/rl/train_locally.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ def train_eval(agent_name=constant.AgentName.PPO,
6868
train_sequence_length=1,
6969
deploy_policy_name='saved_policy',
7070
use_random_network_distillation=False,
71-
moving_average_decay_rate=1,
72-
additional_compilation_flags=(),
73-
delete_compilation_flags=()):
71+
moving_average_decay_rate=1):
7472
"""Train for LLVM inliner."""
7573
root_dir = FLAGS.root_dir
7674
problem_config = registry.get_configuration()
@@ -102,7 +100,8 @@ def train_eval(agent_name=constant.AgentName.PPO,
102100

103101
logging.info('Loading module specs from corpus.')
104102
module_specs = corpus.build_modulespecs_from_datapath(
105-
FLAGS.data_path, additional_compilation_flags, delete_compilation_flags)
103+
FLAGS.data_path, problem_config.flags_to_add(),
104+
problem_config.flags_to_delete())
106105
logging.info('Done loading module specs from corpus.')
107106

108107
dataset_fn = data_reader.create_sequence_example_dataset_fn(

compiler_opt/tools/generate_default_trace.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,13 @@ def main(_):
136136
_GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
137137
logging.info(gin.config_str())
138138

139+
config = registry.get_configuration()
140+
139141
logging.info('Loading module specs from corpus.')
140142
module_specs = corpus.build_modulespecs_from_datapath(
141143
_DATA_PATH.value,
142-
delete_flags=('-split-dwarf-file', '-split-dwarf-output',
143-
'-fthinlto-index', '-fprofile-sample-use',
144-
'-fprofile-remapping-file'))
144+
additional_flags=config.flags_to_add(),
145+
delete_flags=config.flags_to_delete())
145146
logging.info('Done loading module specs from corpus.')
146147

147148
if _MODULE_FILTER.value:

compiler_opt/tools/generate_default_trace_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,16 @@ def collect_data(self, module_spec, tf_policy_path, reward_stat):
6060

6161

6262
class GenerateDefaultTraceTest(absltest.TestCase):
63+
def setUp(self):
64+
with gin.unlock_config():
65+
gin.parse_config_files_and_bindings(
66+
config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
67+
bindings=None)
68+
return super().setUp()
6369

6470
@mock.patch('compiler_opt.tools.generate_default_trace.get_runner')
6571
def test_api(self, mock_get_runner):
72+
6673
tmp_dir = self.create_tempdir()
6774
module_names = ['a', 'b', 'c', 'd']
6875

@@ -92,10 +99,6 @@ def test_api(self, mock_get_runner):
9299
generate_default_trace.main(None)
93100

94101
def test_get_runner(self):
95-
with gin.unlock_config():
96-
gin.parse_config_files_and_bindings(
97-
config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
98-
bindings=None)
99102
runner = generate_default_trace.get_runner()
100103
self.assertIsInstance(runner, compilation_runner.CompilationRunner)
101104

0 commit comments

Comments
 (0)