Skip to content

Commit 3e6ff9d

Browse files
committed
Some refactoring
1 parent f69b81b commit 3e6ff9d

File tree

6 files changed

+127
-110
lines changed

6 files changed

+127
-110
lines changed

mypy/build.py

Lines changed: 69 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,25 @@
2929
import types
3030
from collections.abc import Iterator, Mapping, Sequence, Set as AbstractSet
3131
from heapq import heappop, heappush
32-
from select import select
3332
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Final, NoReturn, TextIO, TypedDict
3433
from typing_extensions import TypeAlias as _TypeAlias
3534

3635
from librt.internal import cache_version
3736

3837
import mypy.semanal_main
39-
from mypy.cache import CACHE_VERSION, CacheMeta, ReadBuffer, WriteBuffer, read_json, write_json
38+
from mypy.cache import CACHE_VERSION, CacheMeta, ReadBuffer, WriteBuffer
4039
from mypy.checker import TypeChecker
40+
from mypy.defaults import (
41+
WORKER_CONNECTION_TIMEOUT,
42+
WORKER_DONE_TIMEOUT,
43+
WORKER_START_INTERVAL,
44+
WORKER_START_TIMEOUT,
45+
)
4146
from mypy.error_formatter import OUTPUT_CHOICES, ErrorFormatter
4247
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
4348
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
4449
from mypy.indirection import TypeIndirectionVisitor
45-
from mypy.ipc import BadStatus, IPCBase, IPCClient, read_status
50+
from mypy.ipc import BadStatus, IPCClient, read_status, ready_to_read, receive, send
4651
from mypy.messages import MessageBuilder
4752
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable
4853
from mypy.partially_defined import PossiblyUndefinedVariableVisitor
@@ -172,66 +177,13 @@ def __init__(self, manager: BuildManager, graph: Graph) -> None:
172177
self.errors: list[str] = [] # Filled in by build if desired
173178

174179

175-
def receive(connection: IPCBase) -> dict[str, Any]:
176-
"""Receive single JSON data frame from a connection.
177-
178-
Raise OSError if the data received is not valid JSON or if it is
179-
not a dict.
180-
"""
181-
bdata = connection.read_bytes()
182-
if not bdata:
183-
raise OSError("No data received")
184-
try:
185-
buf = ReadBuffer(bdata)
186-
data = read_json(buf)
187-
except Exception as e:
188-
raise OSError("Data received is not valid JSON dict") from e
189-
return data
190-
191-
192-
def send(connection: IPCBase, data: dict[str, Any]) -> None:
193-
"""Send data to a connection encoded and framed.
194-
195-
The data must be JSON-serializable. We assume that a single send call is a
196-
single frame to be sent on the connect.
197-
"""
198-
buf = WriteBuffer()
199-
write_json(buf, data)
200-
connection.write_bytes(buf.getvalue())
201-
202-
203180
class WorkerClient:
204-
def __init__(self, idx: int, conn: IPCClient, proc: subprocess.Popen[bytes]) -> None:
205-
self.idx = idx
181+
def __init__(self, status_file: str, conn: IPCClient, proc: subprocess.Popen[bytes]) -> None:
182+
self.status_file = status_file
206183
self.conn = conn
207184
self.proc = proc
208185

209186

210-
def wait_for_worker(status_file: str, timeout: float = 5.0) -> tuple[int, str]:
211-
"""Wait until the worker is up.
212-
213-
Exit if it doesn't happen within the timeout.
214-
"""
215-
endtime = time.time() + timeout
216-
while time.time() < endtime:
217-
try:
218-
data = read_status(status_file)
219-
except BadStatus:
220-
# If the file isn't there yet, retry later.
221-
time.sleep(0.05)
222-
continue
223-
try:
224-
pid = data["pid"]
225-
connection_name = data["connection_name"]
226-
assert isinstance(pid, int) and isinstance(connection_name, str)
227-
return pid, connection_name
228-
except Exception:
229-
# If the file's content is bogus or the process is dead, fail.
230-
pass
231-
print("Worker process failed to start")
232-
sys.exit(2)
233-
234-
235187
def start_worker(options_data: str, idx: int, env: Mapping[str, str]) -> subprocess.Popen[bytes]:
236188
status_file = f".mypy_worker.{idx}.json"
237189
if os.path.isfile(status_file):
@@ -246,11 +198,31 @@ def start_worker(options_data: str, idx: int, env: Mapping[str, str]) -> subproc
246198
return subprocess.Popen(command, env=env)
247199

248200

249-
def get_worker(idx: int, proc: subprocess.Popen[bytes]) -> WorkerClient:
201+
def wait_for_worker(idx: int, proc: subprocess.Popen[bytes]) -> WorkerClient:
202+
"""Wait until the worker is up.
203+
204+
Exit if it doesn't happen within the timeout.
205+
"""
250206
status_file = f".mypy_worker.{idx}.json"
251-
pid, connection_name = wait_for_worker(status_file)
252-
assert pid == proc.pid
253-
return WorkerClient(idx, IPCClient(connection_name, 10), proc)
207+
endtime = time.time() + WORKER_START_TIMEOUT
208+
while time.time() < endtime:
209+
try:
210+
data = read_status(status_file)
211+
except BadStatus:
212+
# If the file isn't there yet, retry later.
213+
time.sleep(WORKER_START_INTERVAL)
214+
continue
215+
try:
216+
pid, connection_name = data["pid"], data["connection_name"]
217+
assert isinstance(pid, int) and isinstance(connection_name, str)
218+
assert pid == proc.pid
219+
return WorkerClient(
220+
status_file, IPCClient(connection_name, WORKER_CONNECTION_TIMEOUT), proc
221+
)
222+
except Exception:
223+
break
224+
print("Worker process failed to start")
225+
sys.exit(2)
254226

255227

256228
def build_error(msg: str) -> NoReturn:
@@ -308,14 +280,15 @@ def default_flush_errors(
308280
extra_plugins = extra_plugins or []
309281

310282
workers = []
311-
procs = []
312283
if options.num_workers > 0:
313284
pickled_options = pickle.dumps(options.snapshot())
314285
options_data = base64.b64encode(pickled_options).decode()
315-
for i in range(options.num_workers):
316-
procs.append(start_worker(options_data, i, worker_env or os.environ))
317-
for i, proc in enumerate(procs):
318-
workers.append(get_worker(i, proc))
286+
procs = [
287+
start_worker(options_data, idx, worker_env or os.environ)
288+
for idx in range(options.num_workers)
289+
]
290+
for idx, proc in enumerate(procs):
291+
workers.append(wait_for_worker(idx, proc))
319292

320293
for worker in workers:
321294
source_tuples = [(s.path, s.module, s.text, s.base_dir, s.followed) for s in sources]
@@ -353,9 +326,8 @@ def default_flush_errors(
353326
for worker in workers:
354327
worker.conn.close()
355328
worker.proc.wait()
356-
status_file = f".mypy_worker.{worker.idx}.json"
357-
if os.path.isfile(status_file):
358-
os.unlink(status_file)
329+
if os.path.isfile(worker.status_file):
330+
os.unlink(worker.status_file)
359331

360332

361333
def build_inner(
@@ -1037,47 +1009,48 @@ def stats_summary(self) -> Mapping[str, object]:
10371009
return self.stats
10381010

10391011
def submit(self, sccs: list[SCC]) -> None:
1040-
"""Submit a stale SCC for processing in current process."""
1012+
"""Submit a stale SCC for processing in current process or parallel workers."""
10411013
if self.workers:
1014+
self.submit_to_workers(sccs)
1015+
else:
1016+
self.scc_queue.extend([(0, 0, scc) for scc in sccs])
1017+
1018+
def submit_to_workers(self, sccs: list[SCC] | None = None) -> None:
1019+
if sccs is not None:
10421020
for scc in sccs:
10431021
heappush(self.scc_queue, (-scc.size_hint, self.queue_order, scc))
10441022
self.queue_order += 1
1045-
else:
1046-
self.scc_queue.extend([(0, 0, scc) for scc in sccs])
10471023
while self.scc_queue and self.free_workers:
1048-
worker = self.free_workers.pop()
1049-
if self.workers:
1050-
_, _, scc = heappop(self.scc_queue)
1051-
else:
1052-
_, _, scc = self.scc_queue.pop(0)
1053-
send(self.workers[worker].conn, {"scc_id": scc.id})
1024+
idx = self.free_workers.pop()
1025+
_, _, scc = heappop(self.scc_queue)
1026+
send(self.workers[idx].conn, {"scc_id": scc.id})
10541027

10551028
def wait_for_done(
10561029
self, graph: Graph
10571030
) -> tuple[list[SCC], bool, dict[str, tuple[str, list[str]]]]:
1058-
"""Wait for a stale SCC processing (in process) to finish.
1031+
"""Wait for a stale SCC processing to finish.
10591032
1060-
Return next processed SCC and whether we have more in the queue.
1061-
This emulates the API we will have for parallel processing
1062-
in multiple worker processes.
1033+
Return a tuple three items:
1034+
* processed SCCs
1035+
* whether we have more in the queue
1036+
* new interface hash and list of errors for each module
1037+
The last item is only used for parallel processing.
10631038
"""
1064-
if not self.workers:
1065-
if not self.scc_queue:
1066-
return [], False, {}
1067-
_, _, next_scc = self.scc_queue.pop(0)
1068-
process_stale_scc(graph, next_scc, self)
1069-
return [next_scc], bool(self.scc_queue), {}
1039+
if self.workers:
1040+
return self.wait_for_done_workers()
1041+
if not self.scc_queue:
1042+
return [], False, {}
1043+
_, _, next_scc = self.scc_queue.pop(0)
1044+
process_stale_scc(graph, next_scc, self)
1045+
return [next_scc], bool(self.scc_queue), {}
10701046

1047+
def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, list[str]]]]:
10711048
if not self.scc_queue and len(self.free_workers) == len(self.workers):
10721049
return [], False, {}
10731050

1074-
# TODO: don't select from free workers.
1075-
conns = [w.conn.connection for w in self.workers]
1076-
ready, _, _ = select(conns, [], [], 100)
10771051
done_sccs = []
10781052
results = {}
1079-
for r in ready:
1080-
idx = conns.index(r)
1053+
for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT):
10811054
data = receive(self.workers[idx].conn)
10821055
self.free_workers.add(idx)
10831056
scc_id = data["scc_id"]
@@ -1088,7 +1061,7 @@ def wait_for_done(
10881061
)
10891062
results.update({k: tuple(v) for k, v in data["result"].items()})
10901063
done_sccs.append(self.scc_by_id[scc_id])
1091-
self.submit([]) # advance after some workers are free.
1064+
self.submit_to_workers() # advance after some workers are free.
10921065
return (
10931066
done_sccs,
10941067
bool(self.scc_queue) or len(self.free_workers) < len(self.workers),
@@ -3534,7 +3507,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
35343507
data = receive(worker.conn)
35353508
assert data["status"] == "ok"
35363509

3537-
manager.free_workers = {w.idx for w in manager.workers}
3510+
manager.free_workers = set(range(manager.options.num_workers))
35383511

35393512
# Prime the ready list with leaf SCCs (that have no dependencies).
35403513
ready = []

mypy/build_worker/worker.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,11 @@
1212
from typing import NamedTuple
1313

1414
from mypy import util
15-
from mypy.build import (
16-
SCC,
17-
BuildManager,
18-
load_graph,
19-
load_plugins,
20-
process_stale_scc,
21-
receive,
22-
send,
23-
)
15+
from mypy.build import SCC, BuildManager, load_graph, load_plugins, process_stale_scc
16+
from mypy.defaults import RECURSION_LIMIT
2417
from mypy.errors import CompileError, Errors, report_internal_error
2518
from mypy.fscache import FileSystemCache
26-
from mypy.ipc import IPCServer
27-
from mypy.main import RECURSION_LIMIT
19+
from mypy.ipc import IPCServer, receive, send
2820
from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths
2921
from mypy.options import Options
3022
from mypy.util import read_py_file

mypy/defaults.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,10 @@
4242
# Threshold after which we sometimes filter out most errors to avoid very
4343
# verbose output. The default is to show all errors.
4444
MANY_ERRORS_THRESHOLD: Final = -1
45+
46+
RECURSION_LIMIT: Final = 2**14
47+
48+
WORKER_START_INTERVAL: Final = 0.03
49+
WORKER_START_TIMEOUT: Final = 3
50+
WORKER_CONNECTION_TIMEOUT: Final = 10
51+
WORKER_DONE_TIMEOUT: Final = 600

mypy/dmypy/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from collections.abc import Mapping
1818
from typing import Any, Callable, NoReturn
1919

20+
from mypy.defaults import RECURSION_LIMIT
2021
from mypy.dmypy_os import alive, kill
2122
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
2223
from mypy.ipc import BadStatus, IPCClient, IPCException, read_status
23-
from mypy.main import RECURSION_LIMIT
2424
from mypy.util import check_python_version, get_terminal_width, should_force_color
2525
from mypy.version import __version__
2626

mypy/ipc.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,13 @@
1313
import shutil
1414
import sys
1515
import tempfile
16+
from select import select
1617
from types import TracebackType
17-
from typing import Callable, Final
18+
from typing import Any, Callable, Final
19+
20+
from librt.internal import ReadBuffer, WriteBuffer
21+
22+
from mypy.cache import read_json, write_json
1823

1924
if sys.platform == "win32":
2025
# This may be private, but it is needed for IPC on Windows, and is basically stable
@@ -346,3 +351,43 @@ def read_status(status_file: str) -> dict[str, object]:
346351
if not isinstance(data, dict):
347352
raise BadStatus("Invalid status file (not a dict)")
348353
return data
354+
355+
356+
def ready_to_read(conns: list[IPCClient], timeout: float | None = None) -> list[int]:
357+
"""Wait until some connections are readable.
358+
359+
Return index of each readable connection in the original list.
360+
"""
361+
# TODO: add Windows support for this.
362+
assert sys.platform != "win32"
363+
connections = [conn.connection for conn in conns]
364+
ready, _, _ = select(connections, [], [], timeout)
365+
return [connections.index(r) for r in ready]
366+
367+
368+
def receive(connection: IPCBase) -> dict[str, Any]:
369+
"""Receive single JSON data frame from a connection.
370+
371+
Raise OSError if the data received is not valid JSON or if it is
372+
not a dict.
373+
"""
374+
bdata = connection.read_bytes()
375+
if not bdata:
376+
raise OSError("No data received")
377+
try:
378+
buf = ReadBuffer(bdata)
379+
data = read_json(buf)
380+
except Exception as e:
381+
raise OSError("Data received is not valid JSON dict") from e
382+
return data
383+
384+
385+
def send(connection: IPCBase, data: dict[str, Any]) -> None:
386+
"""Send data to a connection encoded and framed.
387+
388+
The data must be JSON-serializable. We assume that a single send call is a
389+
single frame to be sent on the connect.
390+
"""
391+
buf = WriteBuffer()
392+
write_json(buf, data)
393+
connection.write_bytes(buf.getvalue())

mypy/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
parse_version,
2121
validate_package_allow_list,
2222
)
23+
from mypy.defaults import RECURSION_LIMIT
2324
from mypy.error_formatter import OUTPUT_CHOICES
2425
from mypy.errors import CompileError
2526
from mypy.find_sources import InvalidSourceList, create_source_list
@@ -42,7 +43,6 @@
4243

4344
orig_stat: Final = os.stat
4445
MEM_PROFILE: Final = False # If True, dump memory profile
45-
RECURSION_LIMIT: Final = 2**14
4646

4747

4848
def stat_proxy(path: str) -> os.stat_result:

0 commit comments

Comments
 (0)