36
36
37
37
from absl import logging
38
38
# pylint: disable=unused-import
39
- from compiler_opt .distributed . worker import Worker , FixedWorkerPool
39
+ from compiler_opt .distributed import worker
40
40
41
41
from contextlib import AbstractContextManager
42
42
from multiprocessing import connection
@@ -59,8 +59,8 @@ class TaskResult:
59
59
value : Any
60
60
61
61
62
- def _run_impl (pipe : connection .Connection , worker_class : 'type[Worker]' , * args ,
63
- ** kwargs ):
62
+ def _run_impl (pipe : connection .Connection , worker_class : 'type[worker. Worker]' ,
63
+ * args , * *kwargs ):
64
64
"""Worker process entrypoint."""
65
65
66
66
# A setting of 1 does not inhibit the while loop below from running since
@@ -111,7 +111,7 @@ def _run(*args, **kwargs):
111
111
raise e
112
112
113
113
114
- def _make_stub (cls : 'type[Worker]' , * args , ** kwargs ):
114
+ def _make_stub (cls : 'type[worker. Worker]' , * args , ** kwargs ):
115
115
116
116
class _Stub :
117
117
"""Client stub to a worker hosted by a process."""
@@ -241,16 +241,17 @@ def __dir__(self):
241
241
class LocalWorkerPoolManager (AbstractContextManager ):
242
242
"""A pool of workers hosted on the local machines, each in its own process."""
243
243
244
- def __init__ (self , worker_class : 'type[Worker]' , count : Optional [int ], * args ,
245
- ** kwargs ):
244
+ def __init__ (self , worker_class : 'type[worker. Worker]' , count : Optional [int ],
245
+ * args , * *kwargs ):
246
246
if not count :
247
247
count = multiprocessing .get_context ().cpu_count ()
248
+ final_kwargs = worker .get_full_worker_args (worker_class , kwargs )
248
249
self ._stubs = [
249
- _make_stub (worker_class , * args , ** kwargs ) for _ in range (count )
250
+ _make_stub (worker_class , * args , ** final_kwargs ) for _ in range (count )
250
251
]
251
252
252
- def __enter__ (self ) -> FixedWorkerPool :
253
- return FixedWorkerPool (workers = self ._stubs , worker_concurrency = 10 )
253
+ def __enter__ (self ) -> worker . FixedWorkerPool :
254
+ return worker . FixedWorkerPool (workers = self ._stubs , worker_concurrency = 10 )
254
255
255
256
def __exit__ (self , * args ):
256
257
# first, trigger killing the worker process and exiting of the msg pump,
0 commit comments