Skip to content

Commit 795dde3

Browse files
authored
Use proper fixed format for parallel type checking IPC (#20565)
This is another small follow-up for #20280 As promised I am switching from an ad-hoc "binary JSON" format for IPC to a proper fixed format. The performance win is likely negligible, but it looks better and can be "statically typed". Note that since we don't have any backwards compatibility requirements nor need for messages to be decoded by a generic reader, I am using a simpler, more compact serialization format that incremental cache.
1 parent e7a19c8 commit 795dde3

File tree

3 files changed

+254
-69
lines changed

3 files changed

+254
-69
lines changed

mypy/build.py

Lines changed: 211 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,40 @@
4040
TypedDict,
4141
)
4242

43-
from librt.internal import cache_version
43+
from librt.internal import (
44+
cache_version,
45+
read_bool,
46+
read_int as read_int_bare,
47+
read_str as read_str_bare,
48+
read_tag,
49+
write_bool,
50+
write_int as write_int_bare,
51+
write_str as write_str_bare,
52+
write_tag,
53+
)
4454

4555
import mypy.semanal_main
4656
from mypy.cache import (
4757
CACHE_VERSION,
58+
DICT_STR_GEN,
59+
LITERAL_NONE,
4860
CacheMeta,
4961
ReadBuffer,
5062
SerializedError,
63+
Tag,
5164
WriteBuffer,
52-
write_json,
65+
read_int,
66+
read_int_list,
67+
read_int_opt,
68+
read_str,
69+
read_str_list,
70+
read_str_opt,
71+
write_int,
72+
write_int_list,
73+
write_int_opt,
74+
write_str,
75+
write_str_list,
76+
write_str_opt,
5377
)
5478
from mypy.checker import TypeChecker
5579
from mypy.defaults import (
@@ -62,7 +86,7 @@
6286
from mypy.errors import CompileError, ErrorInfo, Errors, ErrorTuple, report_internal_error
6387
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
6488
from mypy.indirection import TypeIndirectionVisitor
65-
from mypy.ipc import BadStatus, IPCClient, read_status, ready_to_read, receive, send
89+
from mypy.ipc import BadStatus, IPCClient, IPCMessage, read_status, ready_to_read, receive, send
6690
from mypy.messages import MessageBuilder
6791
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable
6892
from mypy.partially_defined import PossiblyUndefinedVariableVisitor
@@ -310,7 +334,10 @@ def default_flush_errors(
310334
WorkerClient(f".mypy_worker.{idx}.json", options_data, worker_env or os.environ)
311335
for idx in range(options.num_workers)
312336
]
313-
sources_data = sources_to_bytes(sources)
337+
sources_message = SourcesDataMessage(sources=sources)
338+
buf = WriteBuffer()
339+
sources_message.write(buf)
340+
sources_data = buf.getvalue()
314341
for worker in workers:
315342
# Start loading graph in each worker as soon as it is up.
316343
worker.connect()
@@ -342,7 +369,7 @@ def default_flush_errors(
342369
finally:
343370
for worker in workers:
344371
try:
345-
send(worker.conn, {"final": True})
372+
send(worker.conn, SccRequestMessage(scc_id=None))
346373
except OSError:
347374
pass
348375
for worker in workers:
@@ -1049,7 +1076,7 @@ def submit_to_workers(self, sccs: list[SCC] | None = None) -> None:
10491076
while self.scc_queue and self.free_workers:
10501077
idx = self.free_workers.pop()
10511078
_, _, scc = heappop(self.scc_queue)
1052-
send(self.workers[idx].conn, {"scc_id": scc.id})
1079+
send(self.workers[idx].conn, SccRequestMessage(scc_id=scc.id))
10531080

10541081
def wait_for_done(
10551082
self, graph: Graph
@@ -1077,15 +1104,13 @@ def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, l
10771104
done_sccs = []
10781105
results = {}
10791106
for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT):
1080-
data = receive(self.workers[idx].conn)
1107+
data = SccResponseMessage.read(receive(self.workers[idx].conn))
10811108
self.free_workers.add(idx)
1082-
scc_id = data["scc_id"]
1083-
if "blocker" in data:
1084-
blocker = data["blocker"]
1085-
raise CompileError(
1086-
blocker["messages"], blocker["use_stdout"], blocker["module_with_blocker"]
1087-
)
1088-
results.update({k: tuple(v) for k, v in data["result"].items()})
1109+
scc_id = data.scc_id
1110+
if data.blocker is not None:
1111+
raise data.blocker
1112+
assert data.result is not None
1113+
results.update(data.result)
10891114
done_sccs.append(self.scc_by_id[scc_id])
10901115
self.submit_to_workers() # advance after some workers are free.
10911116
return (
@@ -3558,14 +3583,15 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
35583583
manager.top_order = [scc.id for scc in sccs]
35593584

35603585
# Broadcast SCC structure to the parallel workers, since they don't compute it.
3561-
sccs_data = sccs_to_bytes(sccs)
3586+
sccs_message = SccsDataMessage(sccs=sccs)
3587+
buf = WriteBuffer()
3588+
sccs_message.write(buf)
3589+
sccs_data = buf.getvalue()
35623590
for worker in manager.workers:
3563-
data = receive(worker.conn)
3564-
assert data["status"] == "ok"
3591+
AckMessage.read(receive(worker.conn))
35653592
worker.conn.write_bytes(sccs_data)
35663593
for worker in manager.workers:
3567-
data = receive(worker.conn)
3568-
assert data["status"] == "ok"
3594+
AckMessage.read(receive(worker.conn))
35693595

35703596
manager.free_workers = set(range(manager.options.num_workers))
35713597

@@ -3944,20 +3970,6 @@ def write_undocumented_ref_info(
39443970
metastore.write(ref_info_file, json_dumps(deps_json))
39453971

39463972

3947-
def sources_to_bytes(sources: list[BuildSource]) -> bytes:
3948-
source_tuples = [(s.path, s.module, s.text, s.base_dir, s.followed) for s in sources]
3949-
buf = WriteBuffer()
3950-
write_json(buf, {"sources": source_tuples})
3951-
return buf.getvalue()
3952-
3953-
3954-
def sccs_to_bytes(sccs: list[SCC]) -> bytes:
3955-
scc_tuples = [(list(scc.mod_ids), scc.id, list(scc.deps)) for scc in sccs]
3956-
buf = WriteBuffer()
3957-
write_json(buf, {"sccs": scc_tuples})
3958-
return buf.getvalue()
3959-
3960-
39613973
def serialize_codes(errs: list[ErrorTuple]) -> list[SerializedError]:
39623974
return [
39633975
(path, line, column, end_line, end_column, severity, message, code.code if code else None)
@@ -3979,3 +3991,169 @@ def deserialize_codes(errs: list[SerializedError]) -> list[ErrorTuple]:
39793991
)
39803992
for path, line, column, end_line, end_column, severity, message, code in errs
39813993
]
3994+
3995+
3996+
# The IPC message classes and tags for communication with build workers are
3997+
# in this file to avoid import cycles.
3998+
# Note that we use a more compact fixed serialization format than in cache.py.
3999+
# This is because the messages don't need to read by a generic tool, nor there
4000+
# is any need for backwards compatibility. We still reuse some elements from
4001+
# cache.py for convenience, and also some conventions (like using bare ints
4002+
# to specify object size).
4003+
# Note that we can use tags overlapping with cache.py, since they should never
4004+
# appear on the same context.
4005+
ACK_MESSAGE: Final[Tag] = 101
4006+
SCC_REQUEST_MESSAGE: Final[Tag] = 102
4007+
SCC_RESPONSE_MESSAGE: Final[Tag] = 103
4008+
SOURCES_DATA_MESSAGE: Final[Tag] = 104
4009+
SCCS_DATA_MESSAGE: Final[Tag] = 105
4010+
4011+
4012+
class AckMessage(IPCMessage):
4013+
"""An empty message used primarily for synchronization."""
4014+
4015+
@classmethod
4016+
def read(cls, buf: ReadBuffer) -> AckMessage:
4017+
assert read_tag(buf) == ACK_MESSAGE
4018+
return AckMessage()
4019+
4020+
def write(self, buf: WriteBuffer) -> None:
4021+
write_tag(buf, ACK_MESSAGE)
4022+
4023+
4024+
class SccRequestMessage(IPCMessage):
4025+
"""
4026+
A message representing a request to type check an SCC.
4027+
4028+
If scc_id is None, then it means that the coordinator requested a shutdown.
4029+
"""
4030+
4031+
def __init__(self, *, scc_id: int | None) -> None:
4032+
self.scc_id = scc_id
4033+
4034+
@classmethod
4035+
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
4036+
assert read_tag(buf) == SCC_REQUEST_MESSAGE
4037+
return SccRequestMessage(scc_id=read_int_opt(buf))
4038+
4039+
def write(self, buf: WriteBuffer) -> None:
4040+
write_tag(buf, SCC_REQUEST_MESSAGE)
4041+
write_int_opt(buf, self.scc_id)
4042+
4043+
4044+
class SccResponseMessage(IPCMessage):
4045+
"""
4046+
A message representing a result of type checking an SCC.
4047+
4048+
Only one of `result` or `blocker` can be non-None. The latter means there was
4049+
a blocking error while type checking the SCC.
4050+
"""
4051+
4052+
def __init__(
4053+
self,
4054+
*,
4055+
scc_id: int,
4056+
result: dict[str, tuple[str, list[str]]] | None = None,
4057+
blocker: CompileError | None = None,
4058+
) -> None:
4059+
if result is not None:
4060+
assert blocker is None
4061+
if blocker is not None:
4062+
assert result is None
4063+
self.scc_id = scc_id
4064+
self.result = result
4065+
self.blocker = blocker
4066+
4067+
@classmethod
4068+
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
4069+
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
4070+
scc_id = read_int(buf)
4071+
tag = read_tag(buf)
4072+
if tag == LITERAL_NONE:
4073+
return SccResponseMessage(
4074+
scc_id=scc_id,
4075+
blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)),
4076+
)
4077+
else:
4078+
assert tag == DICT_STR_GEN
4079+
return SccResponseMessage(
4080+
scc_id=scc_id,
4081+
result={
4082+
read_str_bare(buf): (read_str(buf), read_str_list(buf))
4083+
for _ in range(read_int_bare(buf))
4084+
},
4085+
)
4086+
4087+
def write(self, buf: WriteBuffer) -> None:
4088+
write_tag(buf, SCC_RESPONSE_MESSAGE)
4089+
write_int(buf, self.scc_id)
4090+
if self.result is None:
4091+
assert self.blocker is not None
4092+
write_tag(buf, LITERAL_NONE)
4093+
write_str_list(buf, self.blocker.messages)
4094+
write_bool(buf, self.blocker.use_stdout)
4095+
write_str_opt(buf, self.blocker.module_with_blocker)
4096+
else:
4097+
write_tag(buf, DICT_STR_GEN)
4098+
write_int_bare(buf, len(self.result))
4099+
for mod_id in sorted(self.result):
4100+
write_str_bare(buf, mod_id)
4101+
hex_hash, errs = self.result[mod_id]
4102+
write_str(buf, hex_hash)
4103+
write_str_list(buf, errs)
4104+
4105+
4106+
class SourcesDataMessage(IPCMessage):
4107+
"""A message wrapping a list of build sources."""
4108+
4109+
def __init__(self, *, sources: list[BuildSource]) -> None:
4110+
self.sources = sources
4111+
4112+
@classmethod
4113+
def read(cls, buf: ReadBuffer) -> SourcesDataMessage:
4114+
assert read_tag(buf) == SOURCES_DATA_MESSAGE
4115+
sources = [
4116+
BuildSource(
4117+
read_str_opt(buf),
4118+
read_str_opt(buf),
4119+
read_str_opt(buf),
4120+
read_str_opt(buf),
4121+
read_bool(buf),
4122+
)
4123+
for _ in range(read_int_bare(buf))
4124+
]
4125+
return SourcesDataMessage(sources=sources)
4126+
4127+
def write(self, buf: WriteBuffer) -> None:
4128+
write_tag(buf, SOURCES_DATA_MESSAGE)
4129+
write_int_bare(buf, len(self.sources))
4130+
for bs in self.sources:
4131+
write_str_opt(buf, bs.path)
4132+
write_str_opt(buf, bs.module)
4133+
write_str_opt(buf, bs.text)
4134+
write_str_opt(buf, bs.base_dir)
4135+
write_bool(buf, bs.followed)
4136+
4137+
4138+
class SccsDataMessage(IPCMessage):
4139+
"""A message wrapping the SCC structure computed by the coordinator."""
4140+
4141+
def __init__(self, *, sccs: list[SCC]) -> None:
4142+
self.sccs = sccs
4143+
4144+
@classmethod
4145+
def read(cls, buf: ReadBuffer) -> SccsDataMessage:
4146+
assert read_tag(buf) == SCCS_DATA_MESSAGE
4147+
sccs = [
4148+
SCC(set(read_str_list(buf)), read_int(buf), read_int_list(buf))
4149+
for _ in range(read_int_bare(buf))
4150+
]
4151+
return SccsDataMessage(sccs=sccs)
4152+
4153+
def write(self, buf: WriteBuffer) -> None:
4154+
write_tag(buf, SCCS_DATA_MESSAGE)
4155+
write_int_bare(buf, len(self.sccs))
4156+
for scc in self.sccs:
4157+
write_str_list(buf, sorted(scc.mod_ids))
4158+
write_int(buf, scc.id)
4159+
write_int_list(buf, sorted(scc.deps))

0 commit comments

Comments
 (0)