Skip to content

Commit e28bf48

Browse files
Add argument to append options to modelling invocations (#474)
This patch adds a new argument to regalloc_trace_worker that can be set in a gin file for appending additional arguments to basic_block_trace_model.
1 parent e4446c1 commit e28bf48

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

compiler_opt/es/regalloc_trace/regalloc_trace_worker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __init__(
125125
corpus_path: str,
126126
copy_corpus_locally_path: str | None = None,
127127
aux_file_replacement_flags: dict[str, str] | None = None,
128+
extra_bb_trace_model_flags: list[str] | None = None,
128129
):
129130
"""Initializes the RegallocTraceWorker class.
130131
@@ -146,10 +147,14 @@ def __init__(
146147
local to the worker. This is intended to be used in distributed
147148
training setups where training corpora and auxiliary files need to be
148149
copied locally before being compiled.
150+
extra_bb_trace_model_flags: Extra flags to pass to the
151+
basic_block_trace_model invocation.
149152
"""
150153
self._clang_path = clang_path
151154
self._basic_block_trace_model_path = basic_block_trace_model_path
152155
self._thread_count = thread_count
156+
self._extra_bb_trace_model_flags = ([] if extra_bb_trace_model_flags is None
157+
else extra_bb_trace_model_flags)
153158

154159
self._has_local_corpus = False
155160
self._corpus_path = corpus_path
@@ -257,6 +262,7 @@ def _evaluate_corpus(self, module_directory: str, function_index_path: str,
257262
f"--thread_count={self._thread_count}",
258263
f"--bb_trace_path={bb_trace_path}", "--model_type=mca"
259264
]
265+
command_vector.extend(self._extra_bb_trace_model_flags)
260266

261267
output = subprocess.run(command_vector, capture_output=True, check=True)
262268

compiler_opt/es/regalloc_trace/regalloc_trace_worker_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,33 @@ def test_remote_corpus_replacement_flags(self):
252252
f"-fprofile-instr-use={copied_profile_path}" in clang_command_lines[0])
253253
self.assertTrue(
254254
f"-fprofile-instr-use={copied_profile_path}" in clang_command_lines[1])
255+
256+
def test_extra_bb_trace_flags(self):
257+
corpus_dir = self.create_tempdir("corpus")
258+
corpus_modules = _setup_corpus(corpus_dir.full_path)
259+
fake_clang_binary = self.create_tempfile("fake_clang")
260+
fake_clang_invocations = self.create_tempfile("fake_clang_invocations")
261+
_create_test_binary(fake_clang_binary.full_path,
262+
fake_clang_invocations.full_path)
263+
fake_bb_trace_model_binary = self.create_tempfile(
264+
"fake_basic_block_trace_model")
265+
fake_bb_trace_model_invocations = self.create_tempfile(
266+
"fake_basic_block_trace_model_invocations")
267+
_create_test_binary(fake_bb_trace_model_binary.full_path,
268+
fake_bb_trace_model_invocations.full_path)
269+
270+
worker = regalloc_trace_worker.RegallocTraceWorker(
271+
gin_config="",
272+
clang_path=fake_clang_binary.full_path,
273+
basic_block_trace_model_path=fake_bb_trace_model_binary.full_path,
274+
thread_count=1,
275+
corpus_path=corpus_dir.full_path,
276+
extra_bb_trace_model_flags=["--extra_flag"])
277+
_ = worker.compile_corpus_and_evaluate(corpus_modules,
278+
"function_index_path.pb",
279+
"bb_trace_path.pb", None)
280+
281+
command_line = fake_bb_trace_model_invocations.read_text().split(
282+
"\n")[0].split()
283+
284+
self.assertTrue("--extra_flag" in command_line)

0 commit comments

Comments
 (0)