Skip to content

Commit 2136716

Browse files
authored
[local worker] Parse sys.argv in child processes (#449)
1 parent 933202c commit 2136716

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@
2727
local thread pool, or, if the task is 'urgent', it executes it promptly.
2828
"""
2929
import concurrent.futures
30+
import sys
3031
import cloudpickle
3132
import dataclasses
3233
import functools
3334
import multiprocessing
3435
import psutil
3536
import threading
3637

37-
from absl import logging
38+
from absl import flags, logging
3839
# pylint: disable=unused-import
3940
from compiler_opt.distributed import worker
4041

@@ -67,8 +68,8 @@ def _get_context():
6768
SerializedClass = bytes
6869

6970

70-
def _run_impl(pipe: connection.Connection, worker_class: SerializedClass, *args,
71-
**kwargs):
71+
def _run_impl(pipe: connection.Connection, worker_class: SerializedClass,
72+
parse_argv: bool, *args, **kwargs):
7273
"""Worker process entrypoint."""
7374

7475
# A setting of 1 does not inhibit the while loop below from running since
@@ -77,7 +78,12 @@ def _run_impl(pipe: connection.Connection, worker_class: SerializedClass, *args,
7778
# spawned at a time which execute given tasks. In the typical clang-spawning
7879
# jobs, this effectively limits the number of clang instances spawned.
7980
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
80-
obj = cloudpickle.loads(worker_class)(*args, **kwargs)
81+
# unpickle the type, which triggers all the necessary imports. This registers
82+
# abseil flags, which can then be parsed.
83+
cls = cloudpickle.loads(worker_class)
84+
if parse_argv:
85+
flags.FLAGS(sys.argv, known_only=True)
86+
obj = cls(*args, **kwargs)
8187

8288
# Pipes are not thread safe
8389
pipe_lock = threading.Lock()
@@ -111,16 +117,16 @@ def on_done(f: concurrent.futures.Future):
111117
pool.submit(application).add_done_callback(make_ondone(task.msgid))
112118

113119

114-
def _run(pipe: connection.Connection, worker_class: SerializedClass, *args,
115-
**kwargs):
120+
def _run(pipe: connection.Connection, worker_class: SerializedClass,
121+
parse_argv: bool, *args, **kwargs):
116122
try:
117-
_run_impl(pipe, worker_class, *args, **kwargs)
123+
_run_impl(pipe, worker_class, parse_argv, *args, **kwargs)
118124
except BaseException as e:
119125
logging.error(e)
120126
raise e
121127

122128

123-
def _make_stub(cls: 'type[worker.Worker]', *args, **kwargs):
129+
def _make_stub(cls: 'type[worker.Worker]', parse_argv: bool, *args, **kwargs):
124130

125131
class _Stub:
126132
"""Client stub to a worker hosted by a process."""
@@ -135,8 +141,8 @@ def __init__(self):
135141
# to handle high priority requests. The expectation is that the user
136142
# achieves concurrency through multiprocessing, not multithreading.
137143
self._process = _get_context().Process(
138-
target=functools.partial(_run, child_pipe, cloudpickle.dumps(cls), *
139-
args, **kwargs))
144+
target=functools.partial(_run, child_pipe, cloudpickle.dumps(cls),
145+
parse_argv, *args, **kwargs))
140146
# lock for the msgid -> reply future map. The map will be set to None
141147
# when we stop.
142148
self._lock = threading.Lock()
@@ -248,13 +254,16 @@ def __dir__(self):
248254

249255

250256
def create_local_worker_pool(worker_cls: 'type[worker.Worker]',
251-
count: int | None, *args,
257+
count: int | None, parse_argv: bool, *args,
252258
**kwargs) -> worker.FixedWorkerPool:
253259
"""Create a local worker pool for worker_cls."""
254260
if not count:
255261
count = _get_context().cpu_count()
256262
final_kwargs = worker.get_full_worker_args(worker_cls, **kwargs)
257-
stubs = [_make_stub(worker_cls, *args, **final_kwargs) for _ in range(count)]
263+
stubs = [
264+
_make_stub(worker_cls, parse_argv, *args, **final_kwargs)
265+
for _ in range(count)
266+
]
258267
return worker.FixedWorkerPool(workers=stubs, worker_concurrency=16)
259268

260269

@@ -274,7 +283,8 @@ class LocalWorkerPoolManager(AbstractContextManager):
274283

275284
def __init__(self, worker_class: 'type[worker.Worker]', count: int | None,
276285
*args, **kwargs):
277-
self._pool = create_local_worker_pool(worker_class, count, *args, **kwargs)
286+
self._pool = create_local_worker_pool(worker_class, count, True, *args,
287+
**kwargs)
278288

279289
def __enter__(self) -> worker.FixedWorkerPool:
280290
return self._pool

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,19 @@
1414
"""Test for local worker manager."""
1515

1616
import concurrent.futures
17+
import sys
1718
import time
19+
from unittest import mock
1820

21+
from absl import flags
1922
from absl.testing import absltest
2023
from compiler_opt.distributed.worker import Worker
2124
from compiler_opt.distributed.local import local_worker_manager
2225
from tf_agents.system import system_multiprocessing as multiprocessing
2326

27+
_TEST_FLAG = flags.DEFINE_integer(
28+
'test_only_flag', default=1, help='A flag used by some tests.')
29+
2430

2531
class JobNormal(Worker):
2632
"""Test worker."""
@@ -65,6 +71,12 @@ def method(self):
6571
time.sleep(3600)
6672

6773

74+
class JobGetFlags(Worker):
75+
76+
def method(self):
77+
return {'argv': sys.argv, 'the_flag': _TEST_FLAG.value}
78+
79+
6880
class LocalWorkerManagerTest(absltest.TestCase):
6981

7082
def test_pool(self):
@@ -114,6 +126,16 @@ def test_worker_crash_while_waiting(self):
114126
with self.assertRaises(concurrent.futures.CancelledError):
115127
_ = f.result()
116128

129+
def test_flag_parsing(self):
130+
with local_worker_manager.LocalWorkerPoolManager(JobGetFlags, 1) as pool:
131+
result = pool.get_currently_active()[0].method().result()
132+
self.assertEqual(result['the_flag'], 1)
133+
134+
with mock.patch('sys.argv', sys.argv + ['--test_only_flag=42']):
135+
with local_worker_manager.LocalWorkerPoolManager(JobGetFlags, 1) as pool:
136+
result = pool.get_currently_active()[0].method().result()
137+
self.assertEqual(result['the_flag'], 42)
138+
117139

118140
if __name__ == '__main__':
119141
multiprocessing.handle_test_main(absltest.main)

0 commit comments

Comments
 (0)