Skip to content

Commit bedba37

Browse files
authored
Explicitly pass the worker's gin config (#166)
If the local worker manager's processes aren't forked (which depends on what the default for `multiprocessing.get_context` is), the gin bindings won't be available in them, so we just pass them explicitly.
1 parent b4f9af7 commit bedba37

File tree

2 files changed

+33
-9
lines changed

2 files changed

+33
-9
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
from absl import logging
3838
# pylint: disable=unused-import
39-
from compiler_opt.distributed.worker import Worker, FixedWorkerPool
39+
from compiler_opt.distributed import worker
4040

4141
from contextlib import AbstractContextManager
4242
from multiprocessing import connection
@@ -59,8 +59,8 @@ class TaskResult:
5959
value: Any
6060

6161

62-
def _run_impl(pipe: connection.Connection, worker_class: 'type[Worker]', *args,
63-
**kwargs):
62+
def _run_impl(pipe: connection.Connection, worker_class: 'type[worker.Worker]',
63+
*args, **kwargs):
6464
"""Worker process entrypoint."""
6565

6666
# A setting of 1 does not inhibit the while loop below from running since
@@ -111,7 +111,7 @@ def _run(*args, **kwargs):
111111
raise e
112112

113113

114-
def _make_stub(cls: 'type[Worker]', *args, **kwargs):
114+
def _make_stub(cls: 'type[worker.Worker]', *args, **kwargs):
115115

116116
class _Stub:
117117
"""Client stub to a worker hosted by a process."""
@@ -241,16 +241,17 @@ def __dir__(self):
241241
class LocalWorkerPoolManager(AbstractContextManager):
242242
"""A pool of workers hosted on the local machines, each in its own process."""
243243

244-
def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
245-
**kwargs):
244+
def __init__(self, worker_class: 'type[worker.Worker]', count: Optional[int],
245+
*args, **kwargs):
246246
if not count:
247247
count = multiprocessing.get_context().cpu_count()
248+
final_kwargs = worker.get_full_worker_args(worker_class, kwargs)
248249
self._stubs = [
249-
_make_stub(worker_class, *args, **kwargs) for _ in range(count)
250+
_make_stub(worker_class, *args, **final_kwargs) for _ in range(count)
250251
]
251252

252-
def __enter__(self) -> FixedWorkerPool:
253-
return FixedWorkerPool(workers=self._stubs, worker_concurrency=10)
253+
def __enter__(self) -> worker.FixedWorkerPool:
254+
return worker.FixedWorkerPool(workers=self._stubs, worker_concurrency=10)
254255

255256
def __exit__(self, *args):
256257
# first, trigger killing the worker process and exiting of the msg pump,

compiler_opt/distributed/worker.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
"""Common abstraction for a worker contract."""
1616

1717
import abc
18+
import sys
1819
from typing import Any, List, Iterable, Optional, Protocol, TypeVar
1920

21+
import gin
22+
2023

2124
class Worker(Protocol):
2225

@@ -86,3 +89,23 @@ def get_exception(worker_future: WorkerFuture) -> Optional[Exception]:
8689
return None
8790
except Exception as e: # pylint: disable=broad-except
8891
return e
92+
93+
94+
def get_full_worker_args(worker_class: 'type[Worker]', current_kwargs):
  """Get the union of the given kwargs and the gin config of `worker_class`.

  This allows the worker hosting process to be set up differently from the
  training process - e.g. there is no need to initialize gin variables there.

  Args:
    worker_class: the worker type whose gin bindings are looked up.
    current_kwargs: the keyword arguments the caller provides explicitly. It
      is not modified.

  Returns:
    A new dict containing `current_kwargs` plus the gin bindings found for
    `worker_class`. On a key collision, the gin binding's value is used.
  """
  try:
    gin_config = gin.get_bindings(worker_class)
  except ValueError:
    # We don't have a way to check if `worker_class` is even known to gin, and
    # it's not a requirement that it be. Tests, for instance, don't use gin.
    gin_config = {}
  # Issue #38: the original code branched on the Python minor version to pick
  # between the `|` operator and dict unpacking. Unpacking produces the same
  # merge (right-hand side wins) on every Python 3 version, so no version
  # check - which also ignored the major version - is needed.
  return {**current_kwargs, **gin_config}

0 commit comments

Comments
 (0)