Skip to content

Commit 75eb0fd

Browse files
authored
Separate worker and worker manager args (#453)
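Without this separation, the worker's constructor arguments and the worker manager's own arguments share a single argument list, which makes it hard to extend the argument list for the local worker manager.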
1 parent 9ba929b commit 75eb0fd

File tree

7 files changed

+58
-35
lines changed

7 files changed

+58
-35
lines changed

compiler_opt/distributed/buffered_scheduler_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def wait_seconds(self, n: int):
3434
time.sleep(n)
3535
return n + 1
3636

37-
with local_worker_manager.LocalWorkerPoolManager(WaitingWorker, 2) as pool:
37+
with local_worker_manager.LocalWorkerPoolManager(
38+
WaitingWorker, count=2) as pool:
3839
_, futures = buffered_scheduler.schedule_on_worker_pool(
3940
lambda w, v: w.wait_seconds(v), range(4), pool)
4041
not_done = futures
@@ -52,7 +53,8 @@ class TheWorker(worker.Worker):
5253
def square(self, the_value, extra_factor=1):
5354
return the_value * the_value * extra_factor
5455

55-
with local_worker_manager.LocalWorkerPoolManager(TheWorker, 2) as pool:
56+
with local_worker_manager.LocalWorkerPoolManager(
57+
TheWorker, count=2) as pool:
5658
workers, futures = buffered_scheduler.schedule_on_worker_pool(
5759
lambda w, v: w.square(v), range(10), pool)
5860
self.assertLen(workers, 2)

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ def __dir__(self):
253253
return _Stub()
254254

255255

256-
def create_local_worker_pool(worker_cls: 'type[worker.Worker]',
257-
count: int | None, parse_argv: bool, *args,
258-
**kwargs) -> worker.FixedWorkerPool:
256+
def _create_local_worker_pool(worker_cls: 'type[worker.Worker]',
257+
count: int | None, parse_argv: bool, *args,
258+
**kwargs) -> worker.FixedWorkerPool:
259259
"""Create a local worker pool for worker_cls."""
260260
if not count:
261261
count = _get_context().cpu_count()
@@ -267,7 +267,7 @@ def create_local_worker_pool(worker_cls: 'type[worker.Worker]',
267267
return worker.FixedWorkerPool(workers=stubs, worker_concurrency=16)
268268

269269

270-
def close_local_worker_pool(pool: worker.FixedWorkerPool):
270+
def _close_local_worker_pool(pool: worker.FixedWorkerPool):
271271
"""Close the given LocalWorkerPool."""
272272
# first, trigger killing the worker process and exiting of the msg pump,
273273
# which will also clear out any pending futures.
@@ -281,16 +281,21 @@ def close_local_worker_pool(pool: worker.FixedWorkerPool):
281281
class LocalWorkerPoolManager(AbstractContextManager):
282282
"""A pool of workers hosted on the local machines, each in its own process."""
283283

284-
def __init__(self, worker_class: 'type[worker.Worker]', count: int | None,
285-
*args, **kwargs):
286-
self._pool = create_local_worker_pool(worker_class, count, True, *args,
287-
**kwargs)
284+
def __init__(self,
285+
worker_class: 'type[worker.Worker]',
286+
*,
287+
count: int | None,
288+
worker_args: tuple = (),
289+
worker_kwargs: dict | None = None):
290+
worker_kwargs = {} if worker_kwargs is None else worker_kwargs
291+
self._pool = _create_local_worker_pool(worker_class, count, True,
292+
*worker_args, **worker_kwargs)
288293

289294
def __enter__(self) -> worker.FixedWorkerPool:
290295
return self._pool
291296

292297
def __exit__(self, *args):
293-
close_local_worker_pool(self._pool)
298+
_close_local_worker_pool(self._pool)
294299

295300
def __del__(self):
296301
self.__exit__()

compiler_opt/distributed/local/local_worker_manager_test.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def test_pool(self):
8585
kwarg = 'bar'
8686

8787
with local_worker_manager.LocalWorkerPoolManager(
88-
JobNormal, 2, arg, kwarg=kwarg) as pool:
88+
JobNormal, count=2, worker_args=(arg,),
89+
worker_kwargs=dict(kwarg=kwarg)) as pool:
8990
p1 = pool.get_currently_active()[0]
9091
p2 = pool.get_currently_active()[1]
9192
set_futures = [p1.set_token(1), p2.set_token(2)]
@@ -108,15 +109,15 @@ def test_pool(self):
108109

109110
def test_failure(self):
110111

111-
with local_worker_manager.LocalWorkerPoolManager(JobFail, 2) as pool:
112+
with local_worker_manager.LocalWorkerPoolManager(JobFail, count=2) as pool:
112113
with self.assertRaises(concurrent.futures.CancelledError):
113114
# this will fail because we didn't pass the arg to the ctor, so the
114115
# worker hosting process will crash.
115116
pool.get_currently_active()[0].method().result()
116117

117118
def test_worker_crash_while_waiting(self):
118119

119-
with local_worker_manager.LocalWorkerPoolManager(JobSlow, 2) as pool:
120+
with local_worker_manager.LocalWorkerPoolManager(JobSlow, count=2) as pool:
120121
p = pool.get_currently_active()[0]
121122
f = p.method()
122123
self.assertFalse(f.done())
@@ -127,12 +128,14 @@ def test_worker_crash_while_waiting(self):
127128
_ = f.result()
128129

129130
def test_flag_parsing(self):
130-
with local_worker_manager.LocalWorkerPoolManager(JobGetFlags, 1) as pool:
131+
with local_worker_manager.LocalWorkerPoolManager(
132+
JobGetFlags, count=1) as pool:
131133
result = pool.get_currently_active()[0].method().result()
132134
self.assertEqual(result['the_flag'], 1)
133135

134136
with mock.patch('sys.argv', sys.argv + ['--test_only_flag=42']):
135-
with local_worker_manager.LocalWorkerPoolManager(JobGetFlags, 1) as pool:
137+
with local_worker_manager.LocalWorkerPoolManager(
138+
JobGetFlags, count=1) as pool:
136139
result = pool.get_currently_active()[0].method().result()
137140
self.assertEqual(result['the_flag'], 42)
138141

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ class BlackboxEvaluatorTests(absltest.TestCase):
2828

2929
def test_sampling_get_results(self):
3030
with local_worker_manager.LocalWorkerPoolManager(
31-
blackbox_test_utils.ESWorker, count=3, arg='', kwarg='') as pool:
31+
blackbox_test_utils.ESWorker,
32+
count=3,
33+
worker_args=('',),
34+
worker_kwargs=dict(kwarg='')) as pool:
3235
perturbations = [b'00', b'01', b'10']
3336
evaluator = blackbox_evaluator.SamplingBlackboxEvaluator(None, 5, 5, None)
3437
# pylint: disable=protected-access
@@ -52,7 +55,10 @@ def test_get_rewards(self):
5255

5356
def test_trace_get_results(self):
5457
with local_worker_manager.LocalWorkerPoolManager(
55-
blackbox_test_utils.ESTraceWorker, count=3, arg='', kwarg='') as pool:
58+
blackbox_test_utils.ESTraceWorker,
59+
count=3,
60+
worker_args=('',),
61+
worker_kwargs=dict(kwarg='')) as pool:
5662
perturbations = [b'00', b'01', b'10']
5763
test_corpus = corpus.create_corpus_for_testing(
5864
location=self.create_tempdir().full_path,
@@ -65,7 +71,10 @@ def test_trace_get_results(self):
6571

6672
def test_trace_set_baseline(self):
6773
with local_worker_manager.LocalWorkerPoolManager(
68-
blackbox_test_utils.ESTraceWorker, count=1, arg='', kwarg='') as pool:
74+
blackbox_test_utils.ESTraceWorker,
75+
count=1,
76+
worker_args=('',),
77+
worker_kwargs=dict(kwarg='')) as pool:
6978
test_corpus = corpus.create_corpus_for_testing(
7079
location=self.create_tempdir().full_path,
7180
elements=[corpus.ModuleSpec(name='name1', size=1)])

compiler_opt/es/blackbox_learner_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,10 @@ def test_prune_skipped_perturbations(self):
146146

147147
def test_run_step(self):
148148
with local_worker_manager.LocalWorkerPoolManager(
149-
blackbox_test_utils.ESWorker, count=3, arg='', kwarg='') as pool:
149+
blackbox_test_utils.ESWorker,
150+
count=3,
151+
worker_args=('',),
152+
worker_kwargs=dict(kwarg='')) as pool:
150153
self._learner.run_step(pool) # pylint: disable=protected-access
151154
# expected length calculated from expected shapes of variables
152155
self.assertEqual(len(self._learner.get_model_weights()), 17154)

compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -955,15 +955,15 @@ def gen_trajectories(
955955
min(os.cpu_count(), num_workers) if num_workers else os.cpu_count())
956956
with worker_manager_class(
957957
worker_class_type,
958-
worker_count,
959-
obs_action_specs=obs_action_spec,
960-
mlgo_task_type=mlgo_task_type,
961-
callable_policies=callable_policies,
962-
explore_on_features=explore_on_features,
963-
persistent_objects_path=persistent_objects_path,
964-
explicit_temps_dir=explicit_temps_dir,
965-
gin_config_str=gin.config_str(),
966-
) as lwm:
958+
count=worker_count,
959+
worker_kwargs=dict(
960+
obs_action_specs=obs_action_spec,
961+
mlgo_task_type=mlgo_task_type,
962+
callable_policies=callable_policies,
963+
explore_on_features=explore_on_features,
964+
persistent_objects_path=persistent_objects_path,
965+
explicit_temps_dir=explicit_temps_dir,
966+
gin_config_str=gin.config_str())) as lwm:
967967

968968
_, result_futures = buffered_scheduler.schedule_on_worker_pool(
969969
action=lambda w, j: w.select_best_exploration(loaded_module_spec=j),

compiler_opt/tools/generate_default_trace.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ def generate_trace(
151151
with performance_context as performance_writer:
152152
with worker_manager_class(
153153
FilteringWorker,
154-
_NUM_WORKERS.value,
155-
policy_path=_POLICY_PATH.value,
156-
key_filter=_KEY_FILTER.value,
157-
runner_type=runner_type,
158-
runner_kwargs=worker.get_full_worker_args(
159-
runner_type, moving_average_decay_rate=0)) as lwm:
154+
count=_NUM_WORKERS.value,
155+
worker_kwargs=dict(
156+
policy_path=_POLICY_PATH.value,
157+
key_filter=_KEY_FILTER.value,
158+
runner_type=runner_type,
159+
runner_kwargs=worker.get_full_worker_args(
160+
runner_type, moving_average_decay_rate=0))) as lwm:
160161

161162
_, result_futures = buffered_scheduler.schedule_on_worker_pool(
162163
action=lambda w, j: w.compile_and_filter(j),

0 commit comments

Comments
 (0)