
Commit ae6accb

Make the corpus remote wrt workers (#141)
This change removes the assumption that the corpus is co-located with the workers. The new design relies on the assumption that the workers are reachable via a high-throughput network (datacenter setup). The trainer prefetches the modules for the next iteration and sends them as data blobs to the workers.

The rest of the change trickles out from this:

- the corpus operates internally with a metadata object about modules (`ModuleSpec`)
- workers are given a `LoadedModuleSpec`, which contains the data of the module (and its thinlto index, if needed)
- the final shape of the command line is produced on the worker side

This had the side-effect that responsibilities could be reassigned as follows:

- the corpus constructor handles preparing the `ModuleSpec`s, applying all the necessary flag modification handlers (i.e. add/delete/replace)
- flag modification logic is generic - it doesn't know anything about thinlto, '-cc1', etc.
- flag contextualization happens on the worker side: flags reference a context object as a string formatting value, which is then provided by the worker

In turn, this impacted testing.

This change impacts purely local training similarly to how a previous change remoting the policy did: local files are copied by workers redundantly. It does not appear to have any impact on training performance, but should we want to, we can optimize the local scenario under the hood.
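To make the worker-side contextualization concrete, here is a minimal, self-contained sketch of the idea. Only `LoadedModuleSpec`, its `name` and `loaded_ir` fields, and `build_command_line` appear in the diffs below; the `cmdline_template` field, the `{context}` placeholder name, and the way the input path is appended are illustrative assumptions, not the actual `corpus.py` API.

```python
# Minimal sketch, not the real corpus.py: only build_command_line, name and
# loaded_ir are confirmed by the diffs below; everything else is assumed.
import dataclasses
import os
from typing import Tuple


@dataclasses.dataclass(frozen=True)
class LoadedModuleSpec:
  """A module spec that carries the module's bitcode as a data blob."""
  name: str
  loaded_ir: bytes
  # Flags prepared by the corpus constructor (after the add/delete/replace
  # handlers ran). A flag may reference the worker-provided context via
  # str.format, e.g. '-fthinlto-index={context}/index.thinlto.bc'.
  cmdline_template: Tuple[str, ...] = ()

  def build_command_line(self, working_dir: str) -> Tuple[str, ...]:
    # The trainer cannot know the worker's scratch directory, so the final
    # shape of the command line is only produced here, on the worker.
    input_path = os.path.join(working_dir, os.path.basename(self.name))
    return tuple(
        flag.format(context=working_dir) for flag in self.cmdline_template
    ) + (input_path,)


# Example: the worker materializes the command line in its own scratch dir.
spec = LoadedModuleSpec(
    name='lib/foo.bc',
    loaded_ir=b'BC\xc0\xde',
    cmdline_template=('-O2', '-x', 'ir'))
print(spec.build_command_line('/tmp/work0'))
# -> ('-O2', '-x', 'ir', '/tmp/work0/foo.bc')
```

Keeping the template free of worker-specific paths is what lets the trainer prepare and ship module specs without knowing anything about the workers' filesystems.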
1 parent b04bef6 commit ae6accb

11 files changed, +616 -530 lines changed

compiler_opt/rl/compilation_runner.py

Lines changed: 14 additions & 9 deletions
@@ -24,7 +24,9 @@
 from typing import Dict, List, Optional, Tuple
 
 from absl import flags
-from compiler_opt.distributed.worker import Worker, WorkerFuture
+from absl import logging
+from compiler_opt.distributed.worker import Worker
+from compiler_opt.distributed.worker import WorkerFuture
 from compiler_opt.rl import constant
 from compiler_opt.rl import policy_saver
 from compiler_opt.rl import corpus
@@ -178,6 +180,8 @@ def start_cancellable_process(
   # Disable tensorflow info messages during data collection
   if _QUIET.value:
     command_env['TF_CPP_MIN_LOG_LEVEL'] = '1'
+  else:
+    logging.info(cmdline)
   with subprocess.Popen(
       cmdline,
       env=command_env,
@@ -251,7 +255,7 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
   @abc.abstractmethod
   def collect_data(
       self,
-      module_spec: corpus.ModuleSpec,
+      loaded_module_spec: corpus.LoadedModuleSpec,
       policy: Optional[policy_saver.Policy] = None,
       reward_stat: Optional[Dict[str, RewardStat]] = None
   ) -> WorkerFuture[CompilationResult]:
@@ -310,13 +314,13 @@ def resume_all_work(self):
 
   def collect_data(
       self,
-      module_spec: corpus.ModuleSpec,
+      loaded_module_spec: corpus.LoadedModuleSpec,
       policy: Optional[policy_saver.Policy] = None,
       reward_stat: Optional[Dict[str, RewardStat]] = None) -> CompilationResult:
     """Collect data for the given IR file and policy.
 
     Args:
-      module_spec: a ModuleSpec.
+      loaded_module_spec: a LoadedModuleSpec.
       policy: serialized policy.
       reward_stat: reward stat of this module, None if unknown.
 
@@ -331,21 +335,22 @@ def collect_data(
       ValueError if example under default policy and ml policy does not match.
     """
     with tempfile.TemporaryDirectory() as tempdir:
+      final_cmd_line = loaded_module_spec.build_command_line(tempdir)
       tf_policy_path = ''
       if policy is not None:
         tf_policy_path = os.path.join(tempdir, 'policy')
         policy.to_filesystem(tf_policy_path)
 
       if reward_stat is None:
         default_result = self.compile_fn(
-            module_spec, tf_policy_path='', reward_only=bool(tf_policy_path))
+            final_cmd_line, tf_policy_path='', reward_only=bool(tf_policy_path))
         reward_stat = {
             k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
         }
 
       if tf_policy_path:
         policy_result = self.compile_fn(
-            module_spec, tf_policy_path, reward_only=False)
+            final_cmd_line, tf_policy_path, reward_only=False)
       else:
         policy_result = default_result
 
@@ -358,7 +363,7 @@ def collect_data(
       if k not in reward_stat:
         raise ValueError(
             (f'Example {k} does not exist under default policy for '
-             f'module {module_spec.name}'))
+             f'cmd line: {final_cmd_line}'))
       default_reward = reward_stat[k].default_reward
       moving_average_reward = reward_stat[k].moving_average_reward
       sequence_example = _overwrite_trajectory_reward(
@@ -380,12 +385,12 @@ def collect_data(
         keys=keys)
 
   def compile_fn(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
+      self, command_line: corpus.FullyQualifiedCmdLine, tf_policy_path: str,
       reward_only: bool) -> Dict[str, Tuple[tf.train.SequenceExample, float]]:
     """Compiles for the given IR file under the given policy.
 
     Args:
-      module_spec: a ModuleSpec.
+      command_line: the fully qualified command line.
       tf_policy_path: path to TF policy directory on local disk.
       reward_only: whether only return reward.
 
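The net effect of the hunks above: `collect_data` now receives a `LoadedModuleSpec`, builds the final command line once per scratch directory, and feeds the same command line to both the default compile and the policy compile. A stripped-down sketch of that flow follows; the compile step is a stub, it reuses the illustrative `LoadedModuleSpec` from the block after the commit message, and everything except `build_command_line`, `compile_fn` and `collect_data` is an assumption, not the real runner.

```python
# Illustrative flow only; the real logic lives in
# compiler_opt/rl/compilation_runner.py and also handles cancellation,
# reward stats and sequence examples.
import tempfile
from typing import Dict, Tuple


def compile_fn(command_line: Tuple[str, ...], tf_policy_path: str,
               reward_only: bool) -> Dict[str, float]:
  # Stand-in for the subclass-provided compile_fn, which now takes a fully
  # qualified command line instead of a ModuleSpec.
  del reward_only
  return {'default': float(len(command_line)) + (1.0 if tf_policy_path else 0.0)}


def collect_data(loaded_module_spec, use_policy: bool) -> Dict[str, float]:
  with tempfile.TemporaryDirectory() as tempdir:
    # Worker-side: the command line is finalized here and reused for both
    # the default compile and the policy compile.
    final_cmd_line = loaded_module_spec.build_command_line(tempdir)
    default_result = compile_fn(final_cmd_line, tf_policy_path='',
                                reward_only=use_policy)
    if use_policy:
      return compile_fn(final_cmd_line, tf_policy_path=tempdir,
                        reward_only=False)
    return default_result
```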
compiler_opt/rl/compilation_runner_test.py

Lines changed: 7 additions & 4 deletions
@@ -94,6 +94,9 @@ def _mock_compile_fn(file_paths, tf_policy_path, reward_only): # pylint: disabl
 
 _mock_policy = policy_saver.Policy(bytes(), bytes())
 
+_mock_loaded_module_spec = corpus.LoadedModuleSpec(
+    name='dummy', loaded_ir=bytes())
+
 
 class CompilationRunnerTest(tf.test.TestCase):
 
@@ -111,7 +114,7 @@ def test_policy(self, mock_compile_fn):
     runner = compilation_runner.CompilationRunner(
         moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
     data = runner.collect_data(
-        module_spec=corpus.ModuleSpec(name='dummy'), policy=_mock_policy)
+        loaded_module_spec=_mock_loaded_module_spec, policy=_mock_policy)
     self.assertEqual(2, mock_compile_fn.call_count)
 
     expected_example = _get_sequence_example_with_reward(
@@ -139,7 +142,7 @@ def test_default(self, mock_compile_fn):
     runner = compilation_runner.CompilationRunner(
         moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
 
-    data = runner.collect_data(module_spec=corpus.ModuleSpec(name='dummy'))
+    data = runner.collect_data(loaded_module_spec=_mock_loaded_module_spec)
     # One call when we ask for the default policy, because it can provide both
     # trace and default size.
     self.assertEqual(1, mock_compile_fn.call_count)
@@ -168,7 +171,7 @@ def test_given_default_size(self, mock_compile_fn):
         moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
 
     data = runner.collect_data(
-        module_spec=corpus.ModuleSpec(name='dummy'),
+        loaded_module_spec=_mock_loaded_module_spec,
         policy=_mock_policy,
         reward_stat={
             'default':
@@ -205,7 +208,7 @@ def test_exception_handling(self, mock_compile_fn):
 
     with self.assertRaisesRegex(subprocess.CalledProcessError, 'error'):
       _ = runner.collect_data(
-          module_spec=corpus.ModuleSpec(name='dummy'),
+          loaded_module_spec=_mock_loaded_module_spec,
          policy=_mock_policy,
          reward_stat=None)
     self.assertEqual(1, mock_compile_fn.call_count)
