Skip to content

Commit 0c9fd68

Browse files
authored
LocalWorker: Fix nr threads to 1 (#81)
1 parent fb54e5d commit 0c9fd68

File tree

1 file changed

+8
-16
lines changed

1 file changed

+8
-16
lines changed

compiler_opt/distributed/local/local_worker_manager.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import queue # pylint: disable=unused-import
3636
import threading
3737

38-
import gin
3938
from absl import logging
4039
# pylint: disable=unused-import
4140
from compiler_opt.distributed.worker import Worker
@@ -61,7 +60,7 @@ class TaskResult:
6160

6261

6362
def _run_impl(in_q: 'queue.Queue[Task]', out_q: 'queue.Queue[TaskResult]',
64-
worker_class: 'type[Worker]', threads: int, *args, **kwargs):
63+
worker_class: 'type[Worker]', *args, **kwargs):
6564
"""Worker process entrypoint."""
6665
# Note: the out_q is typed as taking only TaskResult objects, not
6766
# Optional[TaskResult], despite that being the type it is used on the Stub
@@ -73,7 +72,7 @@ def _run_impl(in_q: 'queue.Queue[Task]', out_q: 'queue.Queue[TaskResult]',
7372
# process near-immediately. `threads` only controls how many threads are
7473
# spawned at a time which execute given tasks. In the typical clang-spawning
7574
# jobs, this effectively limits the number of clang instances spawned.
76-
pool = concurrent.futures.ThreadPoolExecutor(max_workers=threads)
75+
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
7776
obj = worker_class(*args, **kwargs)
7877

7978
def make_ondone(msgid):
@@ -109,7 +108,7 @@ def _run(*args, **kwargs):
109108
raise e
110109

111110

112-
def _make_stub(cls: 'type[Worker]', pool_threads: int, *args, **kwargs):
111+
def _make_stub(cls: 'type[Worker]', *args, **kwargs):
113112

114113
class _Stub():
115114
"""Client stub to a worker hosted by a process."""
@@ -120,13 +119,15 @@ def __init__(self):
120119
multiprocessing.get_context().Queue()
121120

122121
# this is the process hosting one worker instance.
122+
# we set aside 1 thread to coordinate running jobs, and the main thread
123+
# to handle high priority requests. The expectation is that the user
124+
# achieves concurrency through multiprocessing, not multithreading.
123125
self._process = multiprocessing.Process(
124126
target=functools.partial(
125127
_run,
126128
worker_class=cls,
127129
in_q=self._send,
128130
out_q=self._receive,
129-
threads=pool_threads,
130131
*args,
131132
**kwargs))
132133
# lock for the msgid -> reply future map. The map will be set to None
@@ -217,25 +218,16 @@ def __dir__(self):
217218
return _Stub()
218219

219220

220-
@gin.configurable
221221
class LocalWorkerPool(AbstractContextManager):
222222
"""A pool of workers hosted on the local machines, each in its own process."""
223223

224-
def __init__(self,
225-
worker_class: 'type[Worker]',
226-
count: Optional[int],
227-
*args,
228-
pool_threads: int = 1,
224+
def __init__(self, worker_class: 'type[Worker]', count: Optional[int], *args,
229225
**kwargs):
230226
if not count:
231227
count = multiprocessing.cpu_count()
232228
self._stubs = [
233-
_make_stub(worker_class, pool_threads, *args, **kwargs)
234-
for _ in range(count // pool_threads)
229+
_make_stub(worker_class, *args, **kwargs) for _ in range(count)
235230
]
236-
# Make sure there's always `count` worker threads, not a rounded `count`
237-
if (remainder := count % pool_threads) != 0:
238-
self._stubs.append(_make_stub(worker_class, remainder, *args, **kwargs))
239231

240232
def __enter__(self):
241233
return self._stubs

0 commit comments

Comments
 (0)