Skip to content

Commit 6efe174

Browse files
authored
Set max workers for LocalWorkerManager (#75)
- Prevents attempting to spawn 32 * min(32, cores) number of clangs (python errors out) - max_worker count is adjustable via `pool_threads` parameter in LocalWorkerPool constructor
1 parent 154f491 commit 6efe174

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import queue # pylint: disable=unused-import
3636
import threading
3737

38+
import gin
3839
from absl import logging
3940
# pylint: disable=unused-import
4041
from compiler_opt.distributed.worker import Worker
@@ -60,12 +61,19 @@ class TaskResult:
6061

6162

6263
def _run_impl(in_q: 'queue.Queue[Task]', out_q: 'queue.Queue[TaskResult]',
63-
worker_class: 'type[Worker]', *args, **kwargs):
64+
worker_class: 'type[Worker]', threads: int, *args, **kwargs):
6465
"""Worker process entrypoint."""
6566
# Note: the out_q is typed as taking only TaskResult objects, not
6667
# Optional[TaskResult], despite that being the type it is used on the Stub
6768
# side. This is because the `None` value is only injected by the Stub itself.
68-
pool = concurrent.futures.ThreadPoolExecutor()
69+
70+
# `threads` is defaulted to 1 in LocalWorkerPool's constructor.
71+
# A setting of 1 does not inhibit the while loop below from running since
72+
# that runs on the main thread of the process. Urgent tasks will still
73+
# process near-immediately. `threads` only controls how many threads are
74+
# spawned at a time which execute given tasks. In the typical clang-spawning
75+
# jobs, this effectively limits the number of clang instances spawned.
76+
pool = concurrent.futures.ThreadPoolExecutor(max_workers=threads)
6977
obj = worker_class(*args, **kwargs)
7078

7179
def make_ondone(msgid):
@@ -101,7 +109,7 @@ def _run(*args, **kwargs):
101109
raise e
102110

103111

104-
def _make_stub(cls: 'type[Worker]', *args, **kwargs):
112+
def _make_stub(cls: 'type[Worker]', pool_threads: int, *args, **kwargs):
105113

106114
class _Stub():
107115
"""Client stub to a worker hosted by a process."""
@@ -118,6 +126,7 @@ def __init__(self):
118126
worker_class=cls,
119127
in_q=self._send,
120128
out_q=self._receive,
129+
threads=pool_threads,
121130
*args,
122131
**kwargs))
123132
# lock for the msgid -> reply future map. The map will be set to None
@@ -208,16 +217,25 @@ def __dir__(self):
208217
return _Stub()
209218

210219

220+
@gin.configurable
211221
class LocalWorkerPool(AbstractContextManager):
212222
"""A pool of workers hosted on the local machines, each in its own process."""
213223

214-
def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
224+
def __init__(self,
225+
worker_class: 'type[Worker]',
226+
count: Optional[int],
227+
*args,
228+
pool_threads: int = 1,
215229
**kwargs):
216230
if not count:
217231
count = multiprocessing.cpu_count()
218232
self._stubs = [
219-
_make_stub(worker_class, *args, **kwargs) for _ in range(count)
233+
_make_stub(worker_class, pool_threads, *args, **kwargs)
234+
for _ in range(count // pool_threads)
220235
]
236+
# Make sure there's always `count` worker threads, not a rounded `count`
237+
if (remainder := count % pool_threads) != 0:
238+
self._stubs.append(_make_stub(worker_class, remainder, *args, **kwargs))
221239

222240
def __enter__(self):
223241
return self._stubs

0 commit comments

Comments
 (0)