27
27
local thread pool, or, if the task is 'urgent', it executes it promptly.
28
28
"""
29
29
import concurrent .futures
30
+ import sys
30
31
import cloudpickle
31
32
import dataclasses
32
33
import functools
33
34
import multiprocessing
34
35
import psutil
35
36
import threading
36
37
37
- from absl import logging
38
+ from absl import flags , logging
38
39
# pylint: disable=unused-import
39
40
from compiler_opt .distributed import worker
40
41
@@ -67,8 +68,8 @@ def _get_context():
67
68
SerializedClass = bytes
68
69
69
70
70
- def _run_impl (pipe : connection .Connection , worker_class : SerializedClass , * args ,
71
- ** kwargs ):
71
+ def _run_impl (pipe : connection .Connection , worker_class : SerializedClass ,
72
+ parse_argv : bool , * args , ** kwargs ):
72
73
"""Worker process entrypoint."""
73
74
74
75
# A setting of 1 does not inhibit the while loop below from running since
@@ -77,7 +78,12 @@ def _run_impl(pipe: connection.Connection, worker_class: SerializedClass, *args,
77
78
# spawned at a time which execute given tasks. In the typical clang-spawning
78
79
# jobs, this effectively limits the number of clang instances spawned.
79
80
pool = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
80
- obj = cloudpickle .loads (worker_class )(* args , ** kwargs )
81
+ # unpickle the type, which triggers all the necessary imports. This registers
82
+ # abseil flags, which can then be parsed.
83
+ cls = cloudpickle .loads (worker_class )
84
+ if parse_argv :
85
+ flags .FLAGS (sys .argv , known_only = True )
86
+ obj = cls (* args , ** kwargs )
81
87
82
88
# Pipes are not thread safe
83
89
pipe_lock = threading .Lock ()
@@ -111,16 +117,16 @@ def on_done(f: concurrent.futures.Future):
111
117
pool .submit (application ).add_done_callback (make_ondone (task .msgid ))
112
118
113
119
114
- def _run (pipe : connection .Connection , worker_class : SerializedClass , * args ,
115
- ** kwargs ):
120
+ def _run (pipe : connection .Connection , worker_class : SerializedClass ,
121
+ parse_argv : bool , * args , ** kwargs ):
116
122
try :
117
- _run_impl (pipe , worker_class , * args , ** kwargs )
123
+ _run_impl (pipe , worker_class , parse_argv , * args , ** kwargs )
118
124
except BaseException as e :
119
125
logging .error (e )
120
126
raise e
121
127
122
128
123
- def _make_stub (cls : 'type[worker.Worker]' , * args , ** kwargs ):
129
+ def _make_stub (cls : 'type[worker.Worker]' , parse_argv : bool , * args , ** kwargs ):
124
130
125
131
class _Stub :
126
132
"""Client stub to a worker hosted by a process."""
@@ -135,8 +141,8 @@ def __init__(self):
135
141
# to handle high priority requests. The expectation is that the user
136
142
# achieves concurrency through multiprocessing, not multithreading.
137
143
self ._process = _get_context ().Process (
138
- target = functools .partial (_run , child_pipe , cloudpickle .dumps (cls ), *
139
- args , ** kwargs ))
144
+ target = functools .partial (_run , child_pipe , cloudpickle .dumps (cls ),
145
+ parse_argv , * args , ** kwargs ))
140
146
# lock for the msgid -> reply future map. The map will be set to None
141
147
# when we stop.
142
148
self ._lock = threading .Lock ()
@@ -248,13 +254,16 @@ def __dir__(self):
248
254
249
255
250
256
def create_local_worker_pool (worker_cls : 'type[worker.Worker]' ,
251
- count : int | None , * args ,
257
+ count : int | None , parse_argv : bool , * args ,
252
258
** kwargs ) -> worker .FixedWorkerPool :
253
259
"""Create a local worker pool for worker_cls."""
254
260
if not count :
255
261
count = _get_context ().cpu_count ()
256
262
final_kwargs = worker .get_full_worker_args (worker_cls , ** kwargs )
257
- stubs = [_make_stub (worker_cls , * args , ** final_kwargs ) for _ in range (count )]
263
+ stubs = [
264
+ _make_stub (worker_cls , parse_argv , * args , ** final_kwargs )
265
+ for _ in range (count )
266
+ ]
258
267
return worker .FixedWorkerPool (workers = stubs , worker_concurrency = 16 )
259
268
260
269
@@ -274,7 +283,8 @@ class LocalWorkerPoolManager(AbstractContextManager):
274
283
275
284
def __init__ (self , worker_class : 'type[worker.Worker]' , count : int | None ,
276
285
* args , ** kwargs ):
277
- self ._pool = create_local_worker_pool (worker_class , count , * args , ** kwargs )
286
+ self ._pool = create_local_worker_pool (worker_class , count , True , * args ,
287
+ ** kwargs )
278
288
279
289
def __enter__ (self ) -> worker .FixedWorkerPool :
280
290
return self ._pool
0 commit comments