diff --git a/mypy/build.py b/mypy/build.py index 71575de9d877..6e49945692b7 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -40,6 +40,7 @@ from typing_extensions import TypeAlias as _TypeAlias import mypy.semanal_main +from mypy.cache import Buffer from mypy.checker import TypeChecker from mypy.error_formatter import OUTPUT_CHOICES, ErrorFormatter from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error @@ -1139,6 +1140,17 @@ def read_deps_cache(manager: BuildManager, graph: Graph) -> dict[str, FgDepMeta] return module_deps_metas +def _load_ff_file(file: str, manager: BuildManager, log_error: str) -> bytes | None: + t0 = time.time() + try: + data = manager.metastore.read(file) + except OSError: + manager.log(log_error + file) + return None + manager.add_stats(metastore_read_time=time.time() - t0) + return data + + def _load_json_file( file: str, manager: BuildManager, log_success: str, log_error: str ) -> dict[str, Any] | None: @@ -1259,7 +1271,11 @@ def get_cache_names(id: str, path: str, options: Options) -> tuple[str, str, str deps_json = None if options.cache_fine_grained: deps_json = prefix + ".deps.json" - return (prefix + ".meta.json", prefix + ".data.json", deps_json) + if options.fixed_format_cache: + data_suffix = ".data.ff" + else: + data_suffix = ".data.json" + return (prefix + ".meta.json", prefix + data_suffix, deps_json) def find_cache_meta(id: str, path: str, manager: BuildManager) -> CacheMeta | None: @@ -1559,8 +1575,13 @@ def write_cache( tree.path = path # Serialize data and analyze interface - data = tree.serialize() - data_bytes = json_dumps(data, manager.options.debug_cache) + if manager.options.fixed_format_cache: + data_io = Buffer() + tree.write(data_io) + data_bytes = data_io.getvalue() + else: + data = tree.serialize() + data_bytes = json_dumps(data, manager.options.debug_cache) interface_hash = hash_digest(data_bytes) plugin_data = manager.plugin.report_config_data(ReportConfigContext(id, path, is_check=False)) @@ -2085,15 +2106,23 @@ def load_tree(self, temporary: bool = False) -> None: self.meta is not None ), "Internal error: this method must be called only for cached modules" - data = _load_json_file( - self.meta.data_json, self.manager, "Load tree ", "Could not load tree: " - ) + data: bytes | dict[str, Any] | None + if self.options.fixed_format_cache: + data = _load_ff_file(self.meta.data_json, self.manager, "Could not load tree: ") + else: + data = _load_json_file( + self.meta.data_json, self.manager, "Load tree ", "Could not load tree: " + ) if data is None: return t0 = time.time() # TODO: Assert data file wasn't changed. 
- self.tree = MypyFile.deserialize(data) + if isinstance(data, bytes): + data_io = Buffer(data) + self.tree = MypyFile.read(data_io) + else: + self.tree = MypyFile.deserialize(data) t1 = time.time() self.manager.add_stats(deserialize_time=t1 - t0) if not temporary: @@ -2481,7 +2510,11 @@ def write_cache(self) -> None: ): if self.options.debug_serialize: try: - self.tree.serialize() + if self.manager.options.fixed_format_cache: + data = Buffer() + self.tree.write(data) + else: + self.tree.serialize() except Exception: print(f"Error serializing {self.id}", file=self.manager.stdout) raise # Propagate to display traceback diff --git a/mypy/cache.py b/mypy/cache.py new file mode 100644 index 000000000000..49f568c1f3c1 --- /dev/null +++ b/mypy/cache.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Final + +try: + from native_internal import ( + Buffer as Buffer, + read_bool as read_bool, + read_float as read_float, + read_int as read_int, + read_str as read_str, + write_bool as write_bool, + write_float as write_float, + write_int as write_int, + write_str as write_str, + ) +except ImportError: + # TODO: temporary, remove this after we publish mypy-native on PyPI. + if not TYPE_CHECKING: + + class Buffer: + def __init__(self, source: bytes = b"") -> None: + raise NotImplementedError + + def getvalue(self) -> bytes: + raise NotImplementedError + + def read_int(data: Buffer) -> int: + raise NotImplementedError + + def write_int(data: Buffer, value: int) -> None: + raise NotImplementedError + + def read_str(data: Buffer) -> str: + raise NotImplementedError + + def write_str(data: Buffer, value: str) -> None: + raise NotImplementedError + + def read_bool(data: Buffer) -> bool: + raise NotImplementedError + + def write_bool(data: Buffer, value: bool) -> None: + raise NotImplementedError + + def read_float(data: Buffer) -> float: + raise NotImplementedError + + def write_float(data: Buffer, value: float) -> None: + raise NotImplementedError + + +LITERAL_INT: Final = 1 +LITERAL_STR: Final = 2 +LITERAL_BOOL: Final = 3 +LITERAL_FLOAT: Final = 4 +LITERAL_COMPLEX: Final = 5 +LITERAL_NONE: Final = 6 + + +def read_literal(data: Buffer, marker: int) -> int | str | bool | float: + if marker == LITERAL_INT: + return read_int(data) + elif marker == LITERAL_STR: + return read_str(data) + elif marker == LITERAL_BOOL: + return read_bool(data) + elif marker == LITERAL_FLOAT: + return read_float(data) + assert False, f"Unknown literal marker {marker}" + + +def write_literal(data: Buffer, value: int | str | bool | float | complex | None) -> None: + if isinstance(value, bool): + write_int(data, LITERAL_BOOL) + write_bool(data, value) + elif isinstance(value, int): + write_int(data, LITERAL_INT) + write_int(data, value) + elif isinstance(value, str): + write_int(data, LITERAL_STR) + write_str(data, value) + elif isinstance(value, float): + write_int(data, LITERAL_FLOAT) + write_float(data, value) + elif isinstance(value, complex): + write_int(data, LITERAL_COMPLEX) + write_float(data, value.real) + write_float(data, value.imag) + else: + write_int(data, LITERAL_NONE) + + +def read_int_opt(data: Buffer) -> int | None: + if read_bool(data): + return read_int(data) + return None + + +def write_int_opt(data: Buffer, value: int | None) -> None: + if value is not None: + write_bool(data, True) + write_int(data, value) + else: + write_bool(data, False) + + +def read_str_opt(data: Buffer) -> str | None: + if read_bool(data): + return 
read_str(data) + return None + + +def write_str_opt(data: Buffer, value: str | None) -> None: + if value is not None: + write_bool(data, True) + write_str(data, value) + else: + write_bool(data, False) + + +def read_int_list(data: Buffer) -> list[int]: + size = read_int(data) + return [read_int(data) for _ in range(size)] + + +def write_int_list(data: Buffer, value: list[int]) -> None: + write_int(data, len(value)) + for item in value: + write_int(data, item) + + +def read_str_list(data: Buffer) -> list[str]: + size = read_int(data) + return [read_str(data) for _ in range(size)] + + +def write_str_list(data: Buffer, value: Sequence[str]) -> None: + write_int(data, len(value)) + for item in value: + write_str(data, item) + + +def read_str_opt_list(data: Buffer) -> list[str | None]: + size = read_int(data) + return [read_str_opt(data) for _ in range(size)] + + +def write_str_opt_list(data: Buffer, value: list[str | None]) -> None: + write_int(data, len(value)) + for item in value: + write_str_opt(data, item) diff --git a/mypy/fixup.py b/mypy/fixup.py index 18bdc1c6f497..bec5929ad4b1 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -97,6 +97,8 @@ def visit_type_info(self, info: TypeInfo) -> None: info.declared_metaclass.accept(self.type_fixer) if info.metaclass_type: info.metaclass_type.accept(self.type_fixer) + if info.self_type: + info.self_type.accept(self.type_fixer) if info.alt_promote: info.alt_promote.accept(self.type_fixer) instance = Instance(info, []) diff --git a/mypy/main.py b/mypy/main.py index fd50c7677a11..0f70eb41bb14 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -1056,6 +1056,9 @@ def add_invertible_flag( action="store_true", help="Include fine-grained dependency information in the cache for the mypy daemon", ) + incremental_group.add_argument( + "--fixed-format-cache", action="store_true", help=argparse.SUPPRESS + ) incremental_group.add_argument( "--skip-version-check", action="store_true", diff --git a/mypy/modulefinder.py b/mypy/modulefinder.py index d159736078eb..d61c9ee3ec3f 100644 --- a/mypy/modulefinder.py +++ b/mypy/modulefinder.py @@ -796,6 +796,7 @@ def default_lib_path( custom_typeshed_dir = os.path.abspath(custom_typeshed_dir) typeshed_dir = os.path.join(custom_typeshed_dir, "stdlib") mypy_extensions_dir = os.path.join(custom_typeshed_dir, "stubs", "mypy-extensions") + mypy_native_dir = os.path.join(custom_typeshed_dir, "stubs", "mypy-native") versions_file = os.path.join(typeshed_dir, "VERSIONS") if not os.path.isdir(typeshed_dir) or not os.path.isfile(versions_file): print( @@ -811,11 +812,13 @@ def default_lib_path( data_dir = auto typeshed_dir = os.path.join(data_dir, "typeshed", "stdlib") mypy_extensions_dir = os.path.join(data_dir, "typeshed", "stubs", "mypy-extensions") + mypy_native_dir = os.path.join(data_dir, "typeshed", "stubs", "mypy-native") path.append(typeshed_dir) - # Get mypy-extensions stubs from typeshed, since we treat it as an - # "internal" library, similar to typing and typing-extensions. + # Get mypy-extensions and mypy-native stubs from typeshed, since we treat them as + # "internal" libraries, similar to typing and typing-extensions. path.append(mypy_extensions_dir) + path.append(mypy_native_dir) # Add fallback path that can be used if we have a broken installation. 
if sys.platform != "win32": diff --git a/mypy/nodes.py b/mypy/nodes.py index 99b9bf72c948..328cd6dade0d 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import os from abc import abstractmethod from collections import defaultdict @@ -13,6 +14,30 @@ from mypy_extensions import trait import mypy.strconv +from mypy.cache import ( + LITERAL_COMPLEX, + LITERAL_NONE, + Buffer, + read_bool, + read_float, + read_int, + read_int_list, + read_int_opt, + read_literal, + read_str, + read_str_list, + read_str_opt, + read_str_opt_list, + write_bool, + write_int, + write_int_list, + write_int_opt, + write_literal, + write_str, + write_str_list, + write_str_opt, + write_str_opt_list, +) from mypy.options import Options from mypy.util import is_sunder, is_typeshed_file, short_type from mypy.visitor import ExpressionVisitor, NodeVisitor, StatementVisitor @@ -240,6 +265,13 @@ def deserialize(cls, data: JsonDict) -> SymbolNode: return method(data) raise NotImplementedError(f"unexpected .class {classname}") + def write(self, data: Buffer) -> None: + raise NotImplementedError(f"Cannot serialize {self.__class__.__name__} instance") + + @classmethod + def read(cls, data: Buffer) -> SymbolNode: + raise NotImplementedError(f"Cannot deserialize {cls.__name__} instance") + # Items: fullname, related symbol table node, surrounding type (if any) Definition: _TypeAlias = tuple[str, "SymbolTableNode", Optional["TypeInfo"]] @@ -368,7 +400,7 @@ def serialize(self) -> JsonDict: "is_stub": self.is_stub, "path": self.path, "is_partial_stub_package": self.is_partial_stub_package, - "future_import_flags": list(self.future_import_flags), + "future_import_flags": sorted(self.future_import_flags), } @classmethod @@ -384,6 +416,28 @@ def deserialize(cls, data: JsonDict) -> MypyFile: tree.future_import_flags = set(data["future_import_flags"]) return tree + def write(self, data: Buffer) -> None: + write_int(data, MYPY_FILE) + write_str(data, self._fullname) + self.names.write(data, self._fullname) + write_bool(data, self.is_stub) + write_str(data, self.path) + write_bool(data, self.is_partial_stub_package) + write_str_list(data, sorted(self.future_import_flags)) + + @classmethod + def read(cls, data: Buffer) -> MypyFile: + assert read_int(data) == MYPY_FILE + tree = MypyFile([], []) + tree._fullname = read_str(data) + tree.names = SymbolTable.read(data) + tree.is_stub = read_bool(data) + tree.path = read_str(data) + tree.is_partial_stub_package = read_bool(data) + tree.future_import_flags = set(read_str_list(data)) + tree.is_cache_skeleton = True + return tree + class ImportBase(Statement): """Base class for all import statements.""" @@ -656,6 +710,41 @@ def deserialize(cls, data: JsonDict) -> OverloadedFuncDef: # NOTE: res.info will be set in the fixup phase. 
return res + def write(self, data: Buffer) -> None: + write_int(data, OVERLOADED_FUNC_DEF) + write_int(data, len(self.items)) + for item in self.items: + item.write(data) + mypy.types.write_type_opt(data, self.type) + write_str(data, self._fullname) + if self.impl is None: + write_bool(data, False) + else: + write_bool(data, True) + self.impl.write(data) + write_flags(data, self, FUNCBASE_FLAGS) + write_str_opt(data, self.deprecated) + write_int_opt(data, self.setter_index) + + @classmethod + def read(cls, data: Buffer) -> OverloadedFuncDef: + res = OverloadedFuncDef([read_overload_part(data) for _ in range(read_int(data))]) + typ = mypy.types.read_type_opt(data) + if typ is not None: + assert isinstance(typ, mypy.types.ProperType) + res.type = typ + res._fullname = read_str(data) + if read_bool(data): + res.impl = read_overload_part(data) + # set line for empty overload items, as not set in __init__ + if len(res.items) > 0: + res.set_line(res.impl.line) + read_flags(data, res, FUNCBASE_FLAGS) + res.deprecated = read_str_opt(data) + res.setter_index = read_int_opt(data) + # NOTE: res.info will be set in the fixup phase. + return res + def is_dynamic(self) -> bool: return all(item.is_dynamic() for item in self.items) @@ -932,6 +1021,46 @@ def deserialize(cls, data: JsonDict) -> FuncDef: del ret.min_args return ret + def write(self, data: Buffer) -> None: + write_int(data, FUNC_DEF) + write_str(data, self._name) + mypy.types.write_type_opt(data, self.type) + write_str(data, self._fullname) + write_flags(data, self, FUNCDEF_FLAGS) + write_str_opt_list(data, self.arg_names) + write_int_list(data, [int(ak.value) for ak in self.arg_kinds]) + write_int(data, self.abstract_status) + if self.dataclass_transform_spec is None: + write_bool(data, False) + else: + write_bool(data, True) + self.dataclass_transform_spec.write(data) + write_str_opt(data, self.deprecated) + write_str_opt(data, self.original_first_arg) + + @classmethod + def read(cls, data: Buffer) -> FuncDef: + name = read_str(data) + typ: mypy.types.FunctionLike | None = None + if read_bool(data): + typ = mypy.types.read_function_like(data) + ret = FuncDef(name, [], Block([]), typ) + ret._fullname = read_str(data) + read_flags(data, ret, FUNCDEF_FLAGS) + # NOTE: ret.info is set in the fixup phase. + ret.arg_names = read_str_opt_list(data) + ret.arg_kinds = [ARG_KINDS[ak] for ak in read_int_list(data)] + ret.abstract_status = read_int(data) + if read_bool(data): + ret.dataclass_transform_spec = DataclassTransformSpec.read(data) + ret.deprecated = read_str_opt(data) + ret.original_first_arg = read_str_opt(data) + # Leave these uninitialized so that future uses will trigger an error + del ret.arguments + del ret.max_pos + del ret.min_args + return ret + # All types that are both SymbolNodes and FuncBases. See the FuncBase # docstring for the rationale. 
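# --- Illustrative sketch (not part of the patch) ---
# A minimal pure-Python model of the wire-format conventions used by the
# write()/read() pairs above: optional values are a presence flag followed by
# the payload, and strings are length-prefixed. The real encoding lives in the
# native_internal C extension and may differ in detail (integer width,
# endianness); ToyBuffer and the toy_* helpers below are hypothetical
# stand-ins, not the actual API.
import struct

class ToyBuffer:
    def __init__(self, source: bytes = b"") -> None:
        self.data = bytearray(source)
        self.pos = 0

    def getvalue(self) -> bytes:
        return bytes(self.data)

def toy_write_bool(data: ToyBuffer, value: bool) -> None:
    data.data.append(1 if value else 0)

def toy_read_bool(data: ToyBuffer) -> bool:
    data.pos += 1
    return data.data[data.pos - 1] != 0

def toy_write_str(data: ToyBuffer, value: str) -> None:
    encoded = value.encode("utf-8")
    data.data += struct.pack("<q", len(encoded))  # length prefix ...
    data.data += encoded  # ... then the content

def toy_read_str(data: ToyBuffer) -> str:
    (size,) = struct.unpack_from("<q", data.data, data.pos)
    data.pos += 8 + size
    return bytes(data.data[data.pos - size : data.pos]).decode("utf-8")

def toy_write_str_opt(data: ToyBuffer, value: str | None) -> None:
    toy_write_bool(data, value is not None)  # presence flag
    if value is not None:
        toy_write_str(data, value)

buf = ToyBuffer()
toy_write_str_opt(buf, "mypy")
toy_write_str_opt(buf, None)
out = ToyBuffer(buf.getvalue())
assert toy_read_bool(out) and toy_read_str(out) == "mypy"
assert not toy_read_bool(out)
# --- end sketch ---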
@@ -1004,6 +1133,22 @@ def deserialize(cls, data: JsonDict) -> Decorator: dec.is_overload = data["is_overload"] return dec + def write(self, data: Buffer) -> None: + write_int(data, DECORATOR) + self.func.write(data) + self.var.write(data) + write_bool(data, self.is_overload) + + @classmethod + def read(cls, data: Buffer) -> Decorator: + assert read_int(data) == FUNC_DEF + func = FuncDef.read(data) + assert read_int(data) == VAR + var = Var.read(data) + dec = Decorator(func, [], var) + dec.is_overload = read_bool(data) + return dec + def is_dynamic(self) -> bool: return self.func.is_dynamic() @@ -1180,6 +1325,35 @@ def deserialize(cls, data: JsonDict) -> Var: v.final_value = data.get("final_value") return v + def write(self, data: Buffer) -> None: + write_int(data, VAR) + write_str(data, self._name) + mypy.types.write_type_opt(data, self.type) + mypy.types.write_type_opt(data, self.setter_type) + write_str(data, self._fullname) + write_flags(data, self, VAR_FLAGS) + write_literal(data, self.final_value) + + @classmethod + def read(cls, data: Buffer) -> Var: + name = read_str(data) + typ = mypy.types.read_type_opt(data) + v = Var(name, typ) + setter_type: mypy.types.CallableType | None = None + if read_bool(data): + assert read_int(data) == mypy.types.CALLABLE_TYPE + setter_type = mypy.types.CallableType.read(data) + v.setter_type = setter_type + v.is_ready = False # Override True default set in __init__ + v._fullname = read_str(data) + read_flags(data, v, VAR_FLAGS) + marker = read_int(data) + if marker == LITERAL_COMPLEX: + v.final_value = complex(read_float(data), read_float(data)) + elif marker != LITERAL_NONE: + v.final_value = read_literal(data, marker) + return v + class ClassDef(Statement): """Class definition""" @@ -1290,6 +1464,22 @@ def deserialize(cls, data: JsonDict) -> ClassDef: res.fullname = data["fullname"] return res + def write(self, data: Buffer) -> None: + write_int(data, CLASS_DEF) + write_str(data, self.name) + mypy.types.write_type_list(data, self.type_vars) + write_str(data, self.fullname) + + @classmethod + def read(cls, data: Buffer) -> ClassDef: + res = ClassDef( + read_str(data), + Block([]), + [mypy.types.read_type_var_like(data) for _ in range(read_int(data))], + ) + res.fullname = read_str(data) + return res + class GlobalDecl(Statement): """Declaration global x, y, ...""" @@ -2707,6 +2897,26 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr: data["variance"], ) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_VAR_EXPR) + write_str(data, self._name) + write_str(data, self._fullname) + mypy.types.write_type_list(data, self.values) + self.upper_bound.write(data) + self.default.write(data) + write_int(data, self.variance) + + @classmethod + def read(cls, data: Buffer) -> TypeVarExpr: + return TypeVarExpr( + read_str(data), + read_str(data), + mypy.types.read_type_list(data), + mypy.types.read_type(data), + mypy.types.read_type(data), + read_int(data), + ) + class ParamSpecExpr(TypeVarLikeExpr): __slots__ = () @@ -2737,6 +2947,24 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr: data["variance"], ) + def write(self, data: Buffer) -> None: + write_int(data, PARAM_SPEC_EXPR) + write_str(data, self._name) + write_str(data, self._fullname) + self.upper_bound.write(data) + self.default.write(data) + write_int(data, self.variance) + + @classmethod + def read(cls, data: Buffer) -> ParamSpecExpr: + return ParamSpecExpr( + read_str(data), + read_str(data), + mypy.types.read_type(data), + mypy.types.read_type(data), + read_int(data), + ) + class 
TypeVarTupleExpr(TypeVarLikeExpr): """Type variable tuple expression TypeVarTuple(...).""" @@ -2787,6 +3015,28 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr: data["variance"], ) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_VAR_TUPLE_EXPR) + self.tuple_fallback.write(data) + write_str(data, self._name) + write_str(data, self._fullname) + self.upper_bound.write(data) + self.default.write(data) + write_int(data, self.variance) + + @classmethod + def read(cls, data: Buffer) -> TypeVarTupleExpr: + assert read_int(data) == mypy.types.INSTANCE + fallback = mypy.types.Instance.read(data) + return TypeVarTupleExpr( + read_str(data), + read_str(data), + mypy.types.read_type(data), + fallback, + mypy.types.read_type(data), + read_int(data), + ) + class TypeAliasExpr(Expression): """Type alias expression (rvalue).""" @@ -3594,7 +3844,6 @@ def deserialize(cls, data: JsonDict) -> TypeInfo: module_name = data["module_name"] ti = TypeInfo(names, defn, module_name) ti._fullname = data["fullname"] - # TODO: Is there a reason to reconstruct ti.subtypes? ti.abstract_attributes = [(attr[0], attr[1]) for attr in data["abstract_attributes"]] ti.type_vars = data["type_vars"] ti.has_param_spec_type = data["has_param_spec_type"] @@ -3654,6 +3903,100 @@ def deserialize(cls, data: JsonDict) -> TypeInfo: ti.deprecated = data.get("deprecated") return ti + def write(self, data: Buffer) -> None: + write_int(data, TYPE_INFO) + self.names.write(data, self.fullname) + self.defn.write(data) + write_str(data, self.module_name) + write_str(data, self.fullname) + write_str_list(data, [a for a, _ in self.abstract_attributes]) + write_int_list(data, [s for _, s in self.abstract_attributes]) + write_str_list(data, self.type_vars) + write_bool(data, self.has_param_spec_type) + mypy.types.write_type_list(data, self.bases) + write_str_list(data, [c.fullname for c in self.mro]) + mypy.types.write_type_list(data, self._promote) + mypy.types.write_type_opt(data, self.alt_promote) + mypy.types.write_type_opt(data, self.declared_metaclass) + mypy.types.write_type_opt(data, self.metaclass_type) + mypy.types.write_type_opt(data, self.tuple_type) + mypy.types.write_type_opt(data, self.typeddict_type) + write_flags(data, self, TypeInfo.FLAGS) + write_str(data, json.dumps(self.metadata)) + if self.slots is None: + write_bool(data, False) + else: + write_bool(data, True) + write_str_list(data, sorted(self.slots)) + write_str_list(data, self.deletable_attributes) + mypy.types.write_type_opt(data, self.self_type) + if self.dataclass_transform_spec is None: + write_bool(data, False) + else: + write_bool(data, True) + self.dataclass_transform_spec.write(data) + write_str_opt(data, self.deprecated) + + @classmethod + def read(cls, data: Buffer) -> TypeInfo: + names = SymbolTable.read(data) + assert read_int(data) == CLASS_DEF + defn = ClassDef.read(data) + module_name = read_str(data) + ti = TypeInfo(names, defn, module_name) + ti._fullname = read_str(data) + attrs = read_str_list(data) + statuses = read_int_list(data) + ti.abstract_attributes = list(zip(attrs, statuses)) + ti.type_vars = read_str_list(data) + ti.has_param_spec_type = read_bool(data) + num_bases = read_int(data) + ti.bases = [] + for _ in range(num_bases): + assert read_int(data) == mypy.types.INSTANCE + ti.bases.append(mypy.types.Instance.read(data)) + # NOTE: ti.mro will be set in the fixup phase based on these + # names. 
The reason we need to store the mro instead of just + # recomputing it from base classes has to do with a subtle + # point about fine-grained incremental: the cache files might + # not be loaded until after a class in the mro has changed its + # bases, which causes the mro to change. If we recomputed our + # mro, we would compute the *new* mro, which leaves us with no + # way to detect that the mro has changed! Thus, we need to make + # sure to load the original mro so that once the class is + # rechecked, it can tell that the mro has changed. + ti._mro_refs = read_str_list(data) + ti._promote = cast(list[mypy.types.ProperType], mypy.types.read_type_list(data)) + if read_bool(data): + assert read_int(data) == mypy.types.INSTANCE + ti.alt_promote = mypy.types.Instance.read(data) + if read_bool(data): + assert read_int(data) == mypy.types.INSTANCE + ti.declared_metaclass = mypy.types.Instance.read(data) + if read_bool(data): + assert read_int(data) == mypy.types.INSTANCE + ti.metaclass_type = mypy.types.Instance.read(data) + if read_bool(data): + assert read_int(data) == mypy.types.TUPLE_TYPE + ti.tuple_type = mypy.types.TupleType.read(data) + if read_bool(data): + assert read_int(data) == mypy.types.TYPED_DICT_TYPE + ti.typeddict_type = mypy.types.TypedDictType.read(data) + read_flags(data, ti, TypeInfo.FLAGS) + metadata = read_str(data) + if metadata != "{}": + ti.metadata = json.loads(metadata) + if read_bool(data): + ti.slots = set(read_str_list(data)) + ti.deletable_attributes = read_str_list(data) + if read_bool(data): + assert read_int(data) == mypy.types.TYPE_VAR_TYPE + ti.self_type = mypy.types.TypeVarType.read(data) + if read_bool(data): + ti.dataclass_transform_spec = DataclassTransformSpec.read(data) + ti.deprecated = read_str_opt(data) + return ti + class FakeInfo(TypeInfo): __slots__ = ("msg",) @@ -3882,6 +4225,9 @@ def fullname(self) -> str: def has_param_spec_type(self) -> bool: return any(isinstance(v, mypy.types.ParamSpecType) for v in self.alias_tvars) + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_type_alias(self) + def serialize(self) -> JsonDict: data: JsonDict = { ".class": "TypeAlias", @@ -3896,9 +4242,6 @@ def serialize(self) -> JsonDict: } return data - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_type_alias(self) - @classmethod def deserialize(cls, data: JsonDict) -> TypeAlias: assert data[".class"] == "TypeAlias" @@ -3922,6 +4265,33 @@ def deserialize(cls, data: JsonDict) -> TypeAlias: python_3_12_type_alias=python_3_12_type_alias, ) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_ALIAS) + write_str(data, self._fullname) + self.target.write(data) + mypy.types.write_type_list(data, self.alias_tvars) + write_int(data, self.line) + write_int(data, self.column) + write_bool(data, self.no_args) + write_bool(data, self.normalized) + write_bool(data, self.python_3_12_type_alias) + + @classmethod + def read(cls, data: Buffer) -> TypeAlias: + fullname = read_str(data) + target = mypy.types.read_type(data) + alias_tvars = [mypy.types.read_type_var_like(data) for _ in range(read_int(data))] + return TypeAlias( + target, + fullname, + read_int(data), + read_int(data), + alias_tvars=alias_tvars, + no_args=read_bool(data), + normalized=read_bool(data), + python_3_12_type_alias=read_bool(data), + ) + class PlaceholderNode(SymbolNode): """Temporary symbol node that will later become a real SymbolNode. 
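# --- Illustrative sketch (not part of the patch) ---
# SymbolTableNode.write() in the next hunk stores a cross-reference (just the
# fullname) for nodes that belong to another module, instead of re-serializing
# them; the fixup pass later replaces the name with the real node. A toy model
# of that decision, with hypothetical names (the real check also special-cases
# Var nodes re-exported via a module __getattr__):

def serialize_entry(node_fullname: str, prefix: str, name: str) -> dict[str, str]:
    # A node whose fullname does not match "<prefix>.<name>" is defined
    # elsewhere, so only its name is recorded.
    if "." in node_fullname and node_fullname != f"{prefix}.{name}":
        return {"cross_ref": node_fullname}
    return {"inline": node_fullname}

assert serialize_entry("mod.f", "mod", "f") == {"inline": "mod.f"}
assert serialize_entry("os.path.join", "mod", "join") == {"cross_ref": "os.path.join"}
# --- end sketch ---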
@@ -4180,6 +4550,49 @@ def deserialize(cls, data: JsonDict) -> SymbolTableNode: stnode.plugin_generated = data["plugin_generated"] return stnode + def write(self, data: Buffer, prefix: str, name: str) -> None: + write_int(data, self.kind) + write_bool(data, self.module_hidden) + write_bool(data, self.module_public) + write_bool(data, self.implicit) + write_bool(data, self.plugin_generated) + + cross_ref = None + if isinstance(self.node, MypyFile): + cross_ref = self.node.fullname + else: + assert self.node is not None, f"{prefix}:{name}" + if prefix is not None: + fullname = self.node.fullname + if ( + "." in fullname + and fullname != prefix + "." + name + and not (isinstance(self.node, Var) and self.node.from_module_getattr) + ): + assert not isinstance( + self.node, PlaceholderNode + ), f"Definition of {fullname} is unexpectedly incomplete" + cross_ref = fullname + + write_str_opt(data, cross_ref) + if cross_ref is None: + assert self.node is not None + self.node.write(data) + + @classmethod + def read(cls, data: Buffer) -> SymbolTableNode: + sym = SymbolTableNode(read_int(data), None) + sym.module_hidden = read_bool(data) + sym.module_public = read_bool(data) + sym.implicit = read_bool(data) + sym.plugin_generated = read_bool(data) + cross_ref = read_str_opt(data) + if cross_ref is None: + sym.node = read_symbol(data) + else: + sym.cross_ref = cross_ref + return sym + class SymbolTable(dict[str, SymbolTableNode]): """Static representation of a namespace dictionary. @@ -4231,6 +4644,29 @@ def deserialize(cls, data: JsonDict) -> SymbolTable: st[key] = SymbolTableNode.deserialize(value) return st + def write(self, data: Buffer, fullname: str) -> None: + size = 0 + for key, value in self.items(): + # Skip __builtins__: it's a reference to the builtins + # module that gets added to every module by + # SemanticAnalyzerPass2.visit_file(), but it shouldn't be + # accessed by users of the module. + if key == "__builtins__" or value.no_serialize: + continue + size += 1 + write_int(data, size) + for key in sorted(self): + value = self[key] + if key == "__builtins__" or value.no_serialize: + continue + write_str(data, key) + value.write(data, fullname, key) + + @classmethod + def read(cls, data: Buffer) -> SymbolTable: + size = read_int(data) + return SymbolTable([(read_str(data), SymbolTableNode.read(data)) for _ in range(size)]) + class DataclassTransformSpec: """Specifies how a dataclass-like transform should be applied. 
The fields here are based on the @@ -4281,6 +4717,23 @@ def deserialize(cls, data: JsonDict) -> DataclassTransformSpec: field_specifiers=tuple(data.get("field_specifiers", [])), ) + def write(self, data: Buffer) -> None: + write_bool(data, self.eq_default) + write_bool(data, self.order_default) + write_bool(data, self.kw_only_default) + write_bool(data, self.frozen_default) + write_str_list(data, self.field_specifiers) + + @classmethod + def read(cls, data: Buffer) -> DataclassTransformSpec: + return DataclassTransformSpec( + eq_default=read_bool(data), + order_default=read_bool(data), + kw_only_default=read_bool(data), + frozen_default=read_bool(data), + field_specifiers=tuple(read_str_list(data)), + ) + def get_flags(node: Node, names: list[str]) -> list[str]: return [name for name in names if getattr(node, name)] @@ -4291,6 +4744,17 @@ def set_flags(node: Node, flags: list[str]) -> None: setattr(node, name, True) +def write_flags(data: Buffer, node: SymbolNode, flags: list[str]) -> None: + for flag in flags: + write_bool(data, getattr(node, flag)) + + +def read_flags(data: Buffer, node: SymbolNode, flags: list[str]) -> None: + for flag in flags: + if read_bool(data): + setattr(node, flag, True) + + def get_member_expr_fullname(expr: MemberExpr) -> str | None: """Return the qualified name representation of a member expression. @@ -4406,3 +4870,49 @@ def local_definitions( yield fullname, symnode, info if isinstance(node, TypeInfo): yield from local_definitions(node.names, fullname, node) + + +MYPY_FILE: Final = 0 +OVERLOADED_FUNC_DEF: Final = 1 +FUNC_DEF: Final = 2 +DECORATOR: Final = 3 +VAR: Final = 4 +TYPE_VAR_EXPR: Final = 5 +PARAM_SPEC_EXPR: Final = 6 +TYPE_VAR_TUPLE_EXPR: Final = 7 +TYPE_INFO: Final = 8 +TYPE_ALIAS: Final = 9 +CLASS_DEF: Final = 10 + + +def read_symbol(data: Buffer) -> mypy.nodes.SymbolNode: + marker = read_int(data) + # The branches here are ordered manually by type "popularity". 
+ if marker == VAR: + return mypy.nodes.Var.read(data) + if marker == FUNC_DEF: + return mypy.nodes.FuncDef.read(data) + if marker == DECORATOR: + return mypy.nodes.Decorator.read(data) + if marker == TYPE_INFO: + return mypy.nodes.TypeInfo.read(data) + if marker == OVERLOADED_FUNC_DEF: + return mypy.nodes.OverloadedFuncDef.read(data) + if marker == TYPE_VAR_EXPR: + return mypy.nodes.TypeVarExpr.read(data) + if marker == TYPE_ALIAS: + return mypy.nodes.TypeAlias.read(data) + if marker == PARAM_SPEC_EXPR: + return mypy.nodes.ParamSpecExpr.read(data) + if marker == TYPE_VAR_TUPLE_EXPR: + return mypy.nodes.TypeVarTupleExpr.read(data) + assert False, f"Unknown symbol marker {marker}" + + +def read_overload_part(data: Buffer) -> OverloadPart: + marker = read_int(data) + if marker == DECORATOR: + return Decorator.read(data) + if marker == FUNC_DEF: + return FuncDef.read(data) + assert False, f"Invalid marker for an OverloadPart {marker}" diff --git a/mypy/options.py b/mypy/options.py index 6d7eca772888..ad4b26cca095 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -72,6 +72,7 @@ class BuildType: "disable_bytearray_promotion", "disable_memoryview_promotion", "strict_bytes", + "fixed_format_cache", } ) - {"debug_cache"} @@ -286,6 +287,7 @@ def __init__(self) -> None: self.incremental = True self.cache_dir = defaults.CACHE_DIR self.sqlite_cache = False + self.fixed_format_cache = False self.debug_cache = False self.skip_version_check = False self.skip_cache_mtime_checks = False diff --git a/mypy/types.py b/mypy/types.py index 26c5b474ba6c..d5df9dc6e7c0 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -10,6 +10,25 @@ import mypy.nodes from mypy.bogus_type import Bogus +from mypy.cache import ( + Buffer, + read_bool, + read_int, + read_int_list, + read_literal, + read_str, + read_str_list, + read_str_opt, + read_str_opt_list, + write_bool, + write_int, + write_int_list, + write_literal, + write_str, + write_str_list, + write_str_opt, + write_str_opt_list, +) from mypy.nodes import ARG_KINDS, ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, SymbolNode from mypy.options import Options from mypy.state import state @@ -273,6 +292,13 @@ def serialize(self) -> JsonDict | str: def deserialize(cls, data: JsonDict) -> Type: raise NotImplementedError(f"Cannot deserialize {cls.__name__} instance") + def write(self, data: Buffer) -> None: + raise NotImplementedError(f"Cannot serialize {self.__class__.__name__} instance") + + @classmethod + def read(cls, data: Buffer) -> Type: + raise NotImplementedError(f"Cannot deserialize {cls.__name__} instance") + def is_singleton_type(self) -> bool: return False @@ -388,6 +414,11 @@ def can_be_false_default(self) -> bool: return self.alias.target.can_be_false return super().can_be_false_default() + def copy_modified(self, *, args: list[Type] | None = None) -> TypeAliasType: + return TypeAliasType( + self.alias, args if args is not None else self.args.copy(), self.line, self.column + ) + def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_type_alias_type(self) @@ -421,10 +452,17 @@ def deserialize(cls, data: JsonDict) -> TypeAliasType: alias.type_ref = data["type_ref"] return alias - def copy_modified(self, *, args: list[Type] | None = None) -> TypeAliasType: - return TypeAliasType( - self.alias, args if args is not None else self.args.copy(), self.line, self.column - ) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_ALIAS_TYPE) + write_type_list(data, self.args) + assert self.alias is not None + write_str(data, self.alias.fullname) + 
+ @classmethod + def read(cls, data: Buffer) -> TypeAliasType: + alias = TypeAliasType(None, read_type_list(data)) + alias.type_ref = read_str(data) + return alias class TypeGuardedType(Type): @@ -693,6 +731,29 @@ def deserialize(cls, data: JsonDict) -> TypeVarType: variance=data["variance"], ) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_VAR_TYPE) + write_str(data, self.name) + write_str(data, self.fullname) + write_int(data, self.id.raw_id) + write_str(data, self.id.namespace) + write_type_list(data, self.values) + self.upper_bound.write(data) + self.default.write(data) + write_int(data, self.variance) + + @classmethod + def read(cls, data: Buffer) -> TypeVarType: + return TypeVarType( + read_str(data), + read_str(data), + TypeVarId(read_int(data), namespace=read_str(data)), + read_type_list(data), + read_type(data), + read_type(data), + read_int(data), + ) + class ParamSpecFlavor: # Simple ParamSpec reference such as "P" @@ -822,6 +883,31 @@ def deserialize(cls, data: JsonDict) -> ParamSpecType: prefix=Parameters.deserialize(data["prefix"]), ) + def write(self, data: Buffer) -> None: + write_int(data, PARAM_SPEC_TYPE) + self.prefix.write(data) + write_str(data, self.name) + write_str(data, self.fullname) + write_int(data, self.id.raw_id) + write_str(data, self.id.namespace) + write_int(data, self.flavor) + self.upper_bound.write(data) + self.default.write(data) + + @classmethod + def read(cls, data: Buffer) -> ParamSpecType: + assert read_int(data) == PARAMETERS + prefix = Parameters.read(data) + return ParamSpecType( + read_str(data), + read_str(data), + TypeVarId(read_int(data), namespace=read_str(data)), + read_int(data), + read_type(data), + read_type(data), + prefix=prefix, + ) + class TypeVarTupleType(TypeVarLikeType): """Type that refers to a TypeVarTuple. @@ -877,6 +963,31 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleType: min_len=data["min_len"], ) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_VAR_TUPLE_TYPE) + self.tuple_fallback.write(data) + write_str(data, self.name) + write_str(data, self.fullname) + write_int(data, self.id.raw_id) + write_str(data, self.id.namespace) + self.upper_bound.write(data) + self.default.write(data) + write_int(data, self.min_len) + + @classmethod + def read(cls, data: Buffer) -> TypeVarTupleType: + assert read_int(data) == INSTANCE + fallback = Instance.read(data) + return TypeVarTupleType( + read_str(data), + read_str(data), + TypeVarId(read_int(data), namespace=read_str(data)), + read_type(data), + fallback, + read_type(data), + min_len=read_int(data), + ) + def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_type_var_tuple(self) @@ -1008,6 +1119,22 @@ def deserialize(cls, data: JsonDict) -> UnboundType: original_str_fallback=data["expr_fallback"], ) + def write(self, data: Buffer) -> None: + write_int(data, UNBOUND_TYPE) + write_str(data, self.name) + write_type_list(data, self.args) + write_str_opt(data, self.original_str_expr) + write_str_opt(data, self.original_str_fallback) + + @classmethod + def read(cls, data: Buffer) -> UnboundType: + return UnboundType( + read_str(data), + read_type_list(data), + original_str_expr=read_str_opt(data), + original_str_fallback=read_str_opt(data), + ) + class CallableArgument(ProperType): """Represents a Arg(type, 'name') inside a Callable's type list. 
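# --- Illustrative sketch (not part of the patch) ---
# TypeAliasType.read() above deliberately builds the type with alias=None and
# stashes the fullname in type_ref; a later fixup pass resolves the name to
# the real TypeAlias node, which is what lets recursive aliases round-trip.
# A toy model of that two-phase scheme (names are hypothetical):

class ToyAliasType:
    def __init__(self, ref: str) -> None:
        self.alias: object | None = None  # unresolved right after read()
        self.type_ref = ref

def toy_fixup(t: ToyAliasType, modules: dict[str, object]) -> None:
    # Second phase: swap the recorded fullname for the actual node.
    t.alias = modules[t.type_ref]

t = ToyAliasType("mod.MyAlias")
toy_fixup(t, {"mod.MyAlias": object()})
assert t.alias is not None
# --- end sketch ---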
@@ -1102,6 +1229,14 @@ def accept(self, visitor: TypeVisitor[T]) -> T: def serialize(self) -> JsonDict: return {".class": "UnpackType", "type": self.type.serialize()} + def write(self, data: Buffer) -> None: + write_int(data, UNPACK_TYPE) + self.type.write(data) + + @classmethod + def read(cls, data: Buffer) -> UnpackType: + return UnpackType(read_type(data)) + @classmethod def deserialize(cls, data: JsonDict) -> UnpackType: assert data[".class"] == "UnpackType" @@ -1203,6 +1338,21 @@ def deserialize(cls, data: JsonDict) -> AnyType: data["missing_import_name"], ) + def write(self, data: Buffer) -> None: + write_int(data, ANY_TYPE) + write_type_opt(data, self.source_any) + write_int(data, self.type_of_any) + write_str_opt(data, self.missing_import_name) + + @classmethod + def read(cls, data: Buffer) -> AnyType: + if read_bool(data): + assert read_int(data) == ANY_TYPE + source_any = AnyType.read(data) + else: + source_any = None + return AnyType(read_int(data), source_any, read_str_opt(data)) + class UninhabitedType(ProperType): """This type has no members. @@ -1249,6 +1399,13 @@ def deserialize(cls, data: JsonDict) -> UninhabitedType: assert data[".class"] == "UninhabitedType" return UninhabitedType() + def write(self, data: Buffer) -> None: + write_int(data, UNINHABITED_TYPE) + + @classmethod + def read(cls, data: Buffer) -> UninhabitedType: + return UninhabitedType() + class NoneType(ProperType): """The type of 'None'. @@ -1281,6 +1438,13 @@ def deserialize(cls, data: JsonDict) -> NoneType: assert data[".class"] == "NoneType" return NoneType() + def write(self, data: Buffer) -> None: + write_int(data, NONE_TYPE) + + @classmethod + def read(cls, data: Buffer) -> NoneType: + return NoneType() + def is_singleton_type(self) -> bool: return True @@ -1328,6 +1492,14 @@ def deserialize(cls, data: JsonDict) -> DeletedType: assert data[".class"] == "DeletedType" return DeletedType(data["source"]) + def write(self, data: Buffer) -> None: + write_int(data, DELETED_TYPE) + write_str_opt(data, self.source) + + @classmethod + def read(cls, data: Buffer) -> DeletedType: + return DeletedType(read_str_opt(data)) + # Fake TypeInfo to be used as a placeholder during Instance de-serialization. NOT_READY: Final = mypy.nodes.FakeInfo("De-serialization failure: TypeInfo not fixed") @@ -1370,7 +1542,7 @@ def serialize(self) -> JsonDict: return { ".class": "ExtraAttrs", "attrs": {k: v.serialize() for k, v in self.attrs.items()}, - "immutable": list(self.immutable), + "immutable": sorted(self.immutable), "mod_name": self.mod_name, } @@ -1383,6 +1555,15 @@ def deserialize(cls, data: JsonDict) -> ExtraAttrs: data["mod_name"], ) + def write(self, data: Buffer) -> None: + write_type_map(data, self.attrs) + write_str_list(data, sorted(self.immutable)) + write_str_opt(data, self.mod_name) + + @classmethod + def read(cls, data: Buffer) -> ExtraAttrs: + return ExtraAttrs(read_type_map(data), set(read_str_list(data)), read_str_opt(data)) + class Instance(ProperType): """An instance type of form C[T1, ..., Tn]. 
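# --- Illustrative sketch (not part of the patch) ---
# Every serialized type or symbol is preceded by an integer tag. Readers that
# statically know the expected class assert the tag (e.g. the
# `assert read_int(data) == INSTANCE` checks before reading fallbacks in the
# hunks below), while generic readers such as read_type() dispatch on it.
# A toy model, with made-up tag values:
TOY_INSTANCE, TOY_NONE = 11, 9

def toy_read_type(stream: list[int]) -> str:
    marker = stream.pop(0)
    if marker == TOY_INSTANCE:
        return "Instance"
    if marker == TOY_NONE:
        return "NoneType"
    raise AssertionError(f"Unknown type marker {marker}")

def toy_read_instance(stream: list[int]) -> str:
    # Caller knows an Instance must come next: check the tag, then read.
    assert stream.pop(0) == TOY_INSTANCE
    return "Instance"

assert toy_read_type([TOY_NONE]) == "NoneType"
assert toy_read_instance([TOY_INSTANCE]) == "Instance"
# --- end sketch ---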
@@ -1519,6 +1700,29 @@ def deserialize(cls, data: JsonDict | str) -> Instance: inst.extra_attrs = ExtraAttrs.deserialize(data["extra_attrs"]) return inst + def write(self, data: Buffer) -> None: + write_int(data, INSTANCE) + write_str(data, self.type.fullname) + write_type_list(data, self.args) + write_type_opt(data, self.last_known_value) + if self.extra_attrs is None: + write_bool(data, False) + else: + write_bool(data, True) + self.extra_attrs.write(data) + + @classmethod + def read(cls, data: Buffer) -> Instance: + type_ref = read_str(data) + inst = Instance(NOT_READY, read_type_list(data)) + inst.type_ref = type_ref + if read_bool(data): + assert read_int(data) == LITERAL_TYPE + inst.last_known_value = LiteralType.read(data) + if read_bool(data): + inst.extra_attrs = ExtraAttrs.read(data) + return inst + def copy_modified( self, *, @@ -1795,6 +1999,26 @@ def deserialize(cls, data: JsonDict) -> Parameters: imprecise_arg_kinds=data["imprecise_arg_kinds"], ) + def write(self, data: Buffer) -> None: + write_int(data, PARAMETERS) + write_type_list(data, self.arg_types) + write_int_list(data, [int(x.value) for x in self.arg_kinds]) + write_str_opt_list(data, self.arg_names) + write_type_list(data, self.variables) + write_bool(data, self.imprecise_arg_kinds) + + @classmethod + def read(cls, data: Buffer) -> Parameters: + return Parameters( + read_type_list(data), + # This is a micro-optimization until mypyc gets dedicated enum support. Otherwise, + # we would spend ~20% of types deserialization time in Enum.__call__(). + [ARG_KINDS[ak] for ak in read_int_list(data)], + read_str_opt_list(data), + variables=[read_type_var_like(data) for _ in range(read_int(data))], + imprecise_arg_kinds=read_bool(data), + ) + def __hash__(self) -> int: return hash( ( @@ -2297,6 +2521,46 @@ def deserialize(cls, data: JsonDict) -> CallableType: unpack_kwargs=data["unpack_kwargs"], ) + def write(self, data: Buffer) -> None: + write_int(data, CALLABLE_TYPE) + self.fallback.write(data) + write_type_list(data, self.arg_types) + write_int_list(data, [int(x.value) for x in self.arg_kinds]) + write_str_opt_list(data, self.arg_names) + self.ret_type.write(data) + write_str_opt(data, self.name) + write_type_list(data, self.variables) + write_bool(data, self.is_ellipsis_args) + write_bool(data, self.implicit) + write_bool(data, self.is_bound) + write_type_opt(data, self.type_guard) + write_type_opt(data, self.type_is) + write_bool(data, self.from_concatenate) + write_bool(data, self.imprecise_arg_kinds) + write_bool(data, self.unpack_kwargs) + + @classmethod + def read(cls, data: Buffer) -> CallableType: + assert read_int(data) == INSTANCE + fallback = Instance.read(data) + return CallableType( + read_type_list(data), + [ARG_KINDS[ak] for ak in read_int_list(data)], + read_str_opt_list(data), + read_type(data), + fallback, + name=read_str_opt(data), + variables=[read_type_var_like(data) for _ in range(read_int(data))], + is_ellipsis_args=read_bool(data), + implicit=read_bool(data), + is_bound=read_bool(data), + type_guard=read_type_opt(data), + type_is=read_type_opt(data), + from_concatenate=read_bool(data), + imprecise_arg_kinds=read_bool(data), + unpack_kwargs=read_bool(data), + ) + # This is a little safety net to prevent reckless special-casing of callables # that can potentially break Unpack[...] with **kwargs. 
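# --- Illustrative sketch (not part of the patch) ---
# Parameters.read() and CallableType.read() above turn raw ints back into
# ArgKind members via the ARG_KINDS lookup list instead of calling the Enum
# constructor, because Enum.__call__ is comparatively expensive under mypyc
# (the comment in Parameters.read() cites ~20% of deserialization time).
# The same trick in isolation:
from enum import Enum

class ToyKind(Enum):
    POS = 0
    STAR = 1

TOY_KINDS = [ToyKind.POS, ToyKind.STAR]  # plain index -> member lookup

raw = [int(k.value) for k in (ToyKind.POS, ToyKind.STAR)]  # written side
assert [TOY_KINDS[v] for v in raw] == [ToyKind.POS, ToyKind.STAR]  # read side
# --- end sketch ---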
@@ -2372,6 +2636,19 @@ def deserialize(cls, data: JsonDict) -> Overloaded: assert data[".class"] == "Overloaded" return Overloaded([CallableType.deserialize(t) for t in data["items"]]) + def write(self, data: Buffer) -> None: + write_int(data, OVERLOADED) + write_type_list(data, self.items) + + @classmethod + def read(cls, data: Buffer) -> Overloaded: + items = [] + num_overloads = read_int(data) + for _ in range(num_overloads): + assert read_int(data) == CALLABLE_TYPE + items.append(CallableType.read(data)) + return Overloaded(items) + class TupleType(ProperType): """The tuple type Tuple[T1, ..., Tn] (at least one type argument). @@ -2468,6 +2745,18 @@ def deserialize(cls, data: JsonDict) -> TupleType: implicit=data["implicit"], ) + def write(self, data: Buffer) -> None: + write_int(data, TUPLE_TYPE) + self.partial_fallback.write(data) + write_type_list(data, self.items) + write_bool(data, self.implicit) + + @classmethod + def read(cls, data: Buffer) -> TupleType: + assert read_int(data) == INSTANCE + fallback = Instance.read(data) + return TupleType(read_type_list(data), fallback, implicit=read_bool(data)) + def copy_modified( self, *, fallback: Instance | None = None, items: list[Type] | None = None ) -> TupleType: @@ -2638,6 +2927,21 @@ def deserialize(cls, data: JsonDict) -> TypedDictType: Instance.deserialize(data["fallback"]), ) + def write(self, data: Buffer) -> None: + write_int(data, TYPED_DICT_TYPE) + self.fallback.write(data) + write_type_map(data, self.items) + write_str_list(data, sorted(self.required_keys)) + write_str_list(data, sorted(self.readonly_keys)) + + @classmethod + def read(cls, data: Buffer) -> TypedDictType: + assert read_int(data) == INSTANCE + fallback = Instance.read(data) + return TypedDictType( + read_type_map(data), set(read_str_list(data)), set(read_str_list(data)), fallback + ) + @property def is_final(self) -> bool: return self.fallback.type.is_final @@ -2886,6 +3190,18 @@ def deserialize(cls, data: JsonDict) -> LiteralType: assert data[".class"] == "LiteralType" return LiteralType(value=data["value"], fallback=Instance.deserialize(data["fallback"])) + def write(self, data: Buffer) -> None: + write_int(data, LITERAL_TYPE) + self.fallback.write(data) + write_literal(data, self.value) + + @classmethod + def read(cls, data: Buffer) -> LiteralType: + assert read_int(data) == INSTANCE + fallback = Instance.read(data) + marker = read_int(data) + return LiteralType(read_literal(data, marker), fallback) + def is_singleton_type(self) -> bool: return self.is_enum_literal() or isinstance(self.value, bool) @@ -2987,6 +3303,15 @@ def deserialize(cls, data: JsonDict) -> UnionType: uses_pep604_syntax=data["uses_pep604_syntax"], ) + def write(self, data: Buffer) -> None: + write_int(data, UNION_TYPE) + write_type_list(data, self.items) + write_bool(data, self.uses_pep604_syntax) + + @classmethod + def read(cls, data: Buffer) -> UnionType: + return UnionType(read_type_list(data), uses_pep604_syntax=read_bool(data)) + class PartialType(ProperType): """Type such as List[?] where type arguments are unknown, or partial None type. 
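# --- Illustrative sketch (not part of the patch) ---
# write_literal()/read_literal() in mypy/cache.py tag each literal value with
# a small int marker. complex is the one asymmetric case: it is written as
# LITERAL_COMPLEX plus two floats (real, imag), and the caller reassembles it
# (see the Var.read branch in mypy/nodes.py), since read_literal() itself only
# handles int/str/bool/float. A trimmed toy version:
LIT_FLOAT, LIT_COMPLEX, LIT_NONE = 4, 5, 6

def toy_write_literal(out: list[object], value: complex | float | None) -> None:
    if isinstance(value, complex):
        out += [LIT_COMPLEX, value.real, value.imag]  # two floats
    elif isinstance(value, float):
        out += [LIT_FLOAT, value]
    else:
        out.append(LIT_NONE)  # missing value

out: list[object] = []
toy_write_literal(out, 1 + 2j)
assert out == [LIT_COMPLEX, 1.0, 2.0]
# --- end sketch ---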
@@ -3123,6 +3448,14 @@ def deserialize(cls, data: JsonDict) -> Type: assert data[".class"] == "TypeType" return TypeType.make_normalized(deserialize_type(data["item"])) + def write(self, data: Buffer) -> None: + write_int(data, TYPE_TYPE) + self.item.write(data) + + @classmethod + def read(cls, data: Buffer) -> Type: + return TypeType.make_normalized(read_type(data)) + class PlaceholderType(ProperType): """Temporary, yet-unknown type during semantic analysis. @@ -3783,6 +4116,128 @@ def type_vars_as_args(type_vars: Sequence[TypeVarLikeType]) -> tuple[Type, ...]: return tuple(args) +TYPE_ALIAS_TYPE: Final = 1 +TYPE_VAR_TYPE: Final = 2 +PARAM_SPEC_TYPE: Final = 3 +TYPE_VAR_TUPLE_TYPE: Final = 4 +UNBOUND_TYPE: Final = 5 +UNPACK_TYPE: Final = 6 +ANY_TYPE: Final = 7 +UNINHABITED_TYPE: Final = 8 +NONE_TYPE: Final = 9 +DELETED_TYPE: Final = 10 +INSTANCE: Final = 11 +CALLABLE_TYPE: Final = 12 +OVERLOADED: Final = 13 +TUPLE_TYPE: Final = 14 +TYPED_DICT_TYPE: Final = 15 +LITERAL_TYPE: Final = 16 +UNION_TYPE: Final = 17 +TYPE_TYPE: Final = 18 +PARAMETERS: Final = 19 + + +def read_type(data: Buffer) -> Type: + marker = read_int(data) + # The branches here are ordered manually by type "popularity". + if marker == INSTANCE: + return Instance.read(data) + if marker == ANY_TYPE: + return AnyType.read(data) + if marker == TYPE_VAR_TYPE: + return TypeVarType.read(data) + if marker == CALLABLE_TYPE: + return CallableType.read(data) + if marker == NONE_TYPE: + return NoneType.read(data) + if marker == UNION_TYPE: + return UnionType.read(data) + if marker == LITERAL_TYPE: + return LiteralType.read(data) + if marker == TYPE_ALIAS_TYPE: + return TypeAliasType.read(data) + if marker == TUPLE_TYPE: + return TupleType.read(data) + if marker == TYPED_DICT_TYPE: + return TypedDictType.read(data) + if marker == TYPE_TYPE: + return TypeType.read(data) + if marker == OVERLOADED: + return Overloaded.read(data) + if marker == PARAM_SPEC_TYPE: + return ParamSpecType.read(data) + if marker == TYPE_VAR_TUPLE_TYPE: + return TypeVarTupleType.read(data) + if marker == UNPACK_TYPE: + return UnpackType.read(data) + if marker == PARAMETERS: + return Parameters.read(data) + if marker == UNINHABITED_TYPE: + return UninhabitedType.read(data) + if marker == UNBOUND_TYPE: + return UnboundType.read(data) + if marker == DELETED_TYPE: + return DeletedType.read(data) + assert False, f"Unknown type marker {marker}" + + +def read_function_like(data: Buffer) -> FunctionLike: + marker = read_int(data) + if marker == CALLABLE_TYPE: + return CallableType.read(data) + if marker == OVERLOADED: + return Overloaded.read(data) + assert False, f"Invalid type marker for FunctionLike {marker}" + + +def read_type_var_like(data: Buffer) -> TypeVarLikeType: + marker = read_int(data) + if marker == TYPE_VAR_TYPE: + return TypeVarType.read(data) + if marker == PARAM_SPEC_TYPE: + return ParamSpecType.read(data) + if marker == TYPE_VAR_TUPLE_TYPE: + return TypeVarTupleType.read(data) + assert False, f"Invalid type marker for TypeVarLikeType {marker}" + + +def read_type_opt(data: Buffer) -> Type | None: + if read_bool(data): + return read_type(data) + return None + + +def write_type_opt(data: Buffer, value: Type | None) -> None: + if value is not None: + write_bool(data, True) + value.write(data) + else: + write_bool(data, False) + + +def read_type_list(data: Buffer) -> list[Type]: + size = read_int(data) + return [read_type(data) for _ in range(size)] + + +def write_type_list(data: Buffer, value: Sequence[Type]) -> None: + write_int(data, len(value)) + for 
item in value:
+        item.write(data)
+
+
+def read_type_map(data: Buffer) -> dict[str, Type]:
+    size = read_int(data)
+    return {read_str(data): read_type(data) for _ in range(size)}
+
+
+def write_type_map(data: Buffer, value: dict[str, Type]) -> None:
+    write_int(data, len(value))
+    for key in sorted(value):
+        write_str(data, key)
+        value[key].write(data)
+
+
 # This cyclic import is unfortunate, but to avoid it we would need to move away all uses
 # of get_proper_type() from types.py. Majority of them have been removed, but few remaining
 # are quite tricky to get rid of, but ultimately we want to do it at some point.
diff --git a/mypy/typeshed/stubs/mypy-native/METADATA.toml b/mypy/typeshed/stubs/mypy-native/METADATA.toml
new file mode 100644
index 000000000000..76574b01cb4b
--- /dev/null
+++ b/mypy/typeshed/stubs/mypy-native/METADATA.toml
@@ -0,0 +1 @@
+version = "0.0.*"
diff --git a/mypy/typeshed/stubs/mypy-native/native_internal.pyi b/mypy/typeshed/stubs/mypy-native/native_internal.pyi
new file mode 100644
index 000000000000..bc1f570a8e9c
--- /dev/null
+++ b/mypy/typeshed/stubs/mypy-native/native_internal.pyi
@@ -0,0 +1,12 @@
+class Buffer:
+    def __init__(self, source: bytes = ...) -> None: ...
+    def getvalue(self) -> bytes: ...
+
+def write_bool(data: Buffer, value: bool) -> None: ...
+def read_bool(data: Buffer) -> bool: ...
+def write_str(data: Buffer, value: str) -> None: ...
+def read_str(data: Buffer) -> str: ...
+def write_float(data: Buffer, value: float) -> None: ...
+def read_float(data: Buffer) -> float: ...
+def write_int(data: Buffer, value: int) -> None: ...
+def read_int(data: Buffer) -> int: ...
diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py
index 4ad2a52c1036..6980c9cee419 100644
--- a/mypyc/analysis/ircheck.py
+++ b/mypyc/analysis/ircheck.py
@@ -56,6 +56,7 @@
 )
 from mypyc.ir.pprint import format_func
 from mypyc.ir.rtypes import (
+    KNOWN_NATIVE_TYPES,
     RArray,
     RInstance,
     RPrimitive,
@@ -181,7 +182,7 @@ def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
     set_rprimitive.name,
     tuple_rprimitive.name,
     range_rprimitive.name,
-}
+} | set(KNOWN_NATIVE_TYPES)


 def can_coerce_to(src: RType, dest: RType) -> bool:
diff --git a/mypyc/build.py b/mypyc/build.py
index 4a2d703b9f10..efbd0dce31db 100644
--- a/mypyc/build.py
+++ b/mypyc/build.py
@@ -492,6 +492,8 @@ def mypycify(
     strict_dunder_typing: bool = False,
     group_name: str | None = None,
     log_trace: bool = False,
+    depends_on_native_internal: bool = False,
+    install_native_libs: bool = False,
 ) -> list[Extension]:
     """Main entry point to building using mypyc.
@@ -542,6 +544,11 @@ def mypycify(
         mypyc_trace.txt (derived from executed operations). This is useful for
         performance analysis, such as analyzing which primitive ops are used the
         most and on which lines.
+        depends_on_native_internal: This is True only for mypy itself.
+        install_native_libs: If True, also build the native extension modules.
+            Normally, those are built and published on PyPI separately, but
+            during tests, we want to use their development versions (i.e. from
+            current commit).
""" # Figure out our configuration @@ -555,6 +562,7 @@ def mypycify( strict_dunder_typing=strict_dunder_typing, group_name=group_name, log_trace=log_trace, + depends_on_native_internal=depends_on_native_internal, ) # Generate all the actual important C code @@ -653,4 +661,21 @@ def mypycify( build_single_module(group_sources, cfilenames + shared_cfilenames, cflags) ) + if install_native_libs: + for name in ["native_internal.c"] + RUNTIME_C_FILES: + rt_file = os.path.join(build_dir, name) + with open(os.path.join(include_dir(), name), encoding="utf-8") as f: + write_file(rt_file, f.read()) + extensions.append( + get_extension()( + "native_internal", + sources=[ + os.path.join(build_dir, file) + for file in ["native_internal.c"] + RUNTIME_C_FILES + ], + include_dirs=[include_dir()], + extra_compile_args=cflags, + ) + ) + return extensions diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 8c4a69cfa3cb..9ca761bd8ac5 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -39,6 +39,7 @@ is_int64_rprimitive, is_int_rprimitive, is_list_rprimitive, + is_native_rprimitive, is_none_rprimitive, is_object_rprimitive, is_optional_type, @@ -704,7 +705,7 @@ def emit_cast( self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) self.emit_line("}") - elif is_object_rprimitive(typ): + elif is_object_rprimitive(typ) or is_native_rprimitive(typ): if declare_dest: self.emit_line(f"PyObject *{dest};") self.emit_arg_check(src, dest, typ, "", optional) diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 1e49b1320b26..e31fcf8ea0c9 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -601,6 +601,8 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]: ext_declarations.emit_line(f"#define MYPYC_NATIVE{self.group_suffix}_H") ext_declarations.emit_line("#include ") ext_declarations.emit_line("#include ") + if self.compiler_options.depends_on_native_internal: + ext_declarations.emit_line("#include ") declarations = Emitter(self.context) declarations.emit_line(f"#ifndef MYPYC_NATIVE_INTERNAL{self.group_suffix}_H") @@ -1027,6 +1029,10 @@ def emit_module_exec_func( declaration = f"int CPyExec_{exported_name(module_name)}(PyObject *module)" module_static = self.module_internal_static_name(module_name, emitter) emitter.emit_lines(declaration, "{") + if self.compiler_options.depends_on_native_internal: + emitter.emit_line("if (import_native_internal() < 0) {") + emitter.emit_line("return -1;") + emitter.emit_line("}") emitter.emit_line("PyObject* modname = NULL;") if self.multi_phase_init: emitter.emit_line(f"{module_static} = module;") @@ -1187,7 +1193,7 @@ def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None: self.declare_global("PyObject *", static_name) def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str: - return emitter.static_name(module_name + "_internal", None, prefix=MODULE_PREFIX) + return emitter.static_name(module_name + "__internal", None, prefix=MODULE_PREFIX) def declare_module(self, module_name: str, emitter: Emitter) -> None: # We declare two globals for each compiled module: diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 3c2fbfec1035..667ff60b0204 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -512,6 +512,15 @@ def __hash__(self) -> int: # Python range object. 
range_rprimitive: Final = RPrimitive("builtins.range", is_unboxed=False, is_refcounted=True) +KNOWN_NATIVE_TYPES: Final = { + name: RPrimitive(name, is_unboxed=False, is_refcounted=True) + for name in ["native_internal.Buffer"] +} + + +def is_native_rprimitive(rtype: RType) -> bool: + return isinstance(rtype, RPrimitive) and rtype.name in KNOWN_NATIVE_TYPES + def is_tagged(rtype: RType) -> TypeGuard[RPrimitive]: return rtype is int_rprimitive or rtype is short_int_rprimitive diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 815688d90fb6..05aa0e45c569 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -25,6 +25,7 @@ from mypyc.ir.class_ir import ClassIR from mypyc.ir.func_ir import FuncDecl, FuncSignature, RuntimeArg from mypyc.ir.rtypes import ( + KNOWN_NATIVE_TYPES, RInstance, RTuple, RType, @@ -119,6 +120,8 @@ def type_to_rtype(self, typ: Type | None) -> RType: return int16_rprimitive elif typ.type.fullname == "mypy_extensions.u8": return uint8_rprimitive + elif typ.type.fullname in KNOWN_NATIVE_TYPES: + return KNOWN_NATIVE_TYPES[typ.type.fullname] else: return object_rprimitive elif isinstance(typ, TupleType): diff --git a/mypyc/lib-rt/native_internal.c b/mypyc/lib-rt/native_internal.c new file mode 100644 index 000000000000..11a3fafee56f --- /dev/null +++ b/mypyc/lib-rt/native_internal.c @@ -0,0 +1,510 @@ +#define PY_SSIZE_T_CLEAN +#include +#include "CPy.h" +#define NATIVE_INTERNAL_MODULE +#include "native_internal.h" + +#define START_SIZE 512 + +typedef struct { + PyObject_HEAD + Py_ssize_t pos; + Py_ssize_t end; + Py_ssize_t size; + char *buf; + PyObject *source; +} BufferObject; + +static PyTypeObject BufferType; + +static PyObject* +Buffer_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + if (type != &BufferType) { + PyErr_SetString(PyExc_TypeError, "Buffer should not be subclassed"); + return NULL; + } + + BufferObject *self = (BufferObject *)type->tp_alloc(type, 0); + if (self != NULL) { + self->pos = 0; + self->end = 0; + self->size = 0; + self->buf = NULL; + } + return (PyObject *) self; +} + + +static int +Buffer_init_internal(BufferObject *self, PyObject *source) { + if (source) { + if (!PyBytes_Check(source)) { + PyErr_SetString(PyExc_TypeError, "source must be a bytes object"); + return -1; + } + self->size = PyBytes_GET_SIZE(source); + self->end = self->size; + // This returns a pointer to internal bytes data, so make our own copy. 
+        char *buf = PyBytes_AsString(source);
+        self->buf = PyMem_Malloc(self->size);
+        if (self->buf == NULL) {
+            PyErr_NoMemory();
+            return -1;
+        }
+        memcpy(self->buf, buf, self->size);
+    } else {
+        self->buf = PyMem_Malloc(START_SIZE);
+        if (self->buf == NULL) {
+            PyErr_NoMemory();
+            return -1;
+        }
+        self->size = START_SIZE;
+    }
+    return 0;
+}
+
+static PyObject*
+Buffer_internal(PyObject *source) {
+    BufferObject *self = (BufferObject *)BufferType.tp_alloc(&BufferType, 0);
+    if (self == NULL)
+        return NULL;
+    self->pos = 0;
+    self->end = 0;
+    self->size = 0;
+    self->buf = NULL;
+    if (Buffer_init_internal(self, source) == -1) {
+        Py_DECREF(self);
+        return NULL;
+    }
+    return (PyObject *)self;
+}
+
+static PyObject*
+Buffer_internal_empty(void) {
+    return Buffer_internal(NULL);
+}
+
+static int
+Buffer_init(BufferObject *self, PyObject *args, PyObject *kwds)
+{
+    static char *kwlist[] = {"source", NULL};
+    PyObject *source = NULL;
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O", kwlist, &source))
+        return -1;
+
+    return Buffer_init_internal(self, source);
+}
+
+static void
+Buffer_dealloc(BufferObject *self)
+{
+    PyMem_Free(self->buf);
+    Py_TYPE(self)->tp_free((PyObject *)self);
+}
+
+static PyObject*
+Buffer_getvalue_internal(PyObject *self)
+{
+    return PyBytes_FromStringAndSize(((BufferObject *)self)->buf, ((BufferObject *)self)->end);
+}
+
+static PyObject*
+Buffer_getvalue(BufferObject *self, PyObject *Py_UNUSED(ignored))
+{
+    return PyBytes_FromStringAndSize(self->buf, self->end);
+}
+
+static PyMethodDef Buffer_methods[] = {
+    {"getvalue", (PyCFunction) Buffer_getvalue, METH_NOARGS,
+     "Return the buffer content as a bytes object"
+    },
+    {NULL}  /* Sentinel */
+};
+
+static PyTypeObject BufferType = {
+    .ob_base = PyVarObject_HEAD_INIT(NULL, 0)
+    .tp_name = "Buffer",
+    .tp_doc = PyDoc_STR("Mypy cache buffer objects"),
+    .tp_basicsize = sizeof(BufferObject),
+    .tp_itemsize = 0,
+    .tp_flags = Py_TPFLAGS_DEFAULT,
+    .tp_new = Buffer_new,
+    .tp_init = (initproc) Buffer_init,
+    .tp_dealloc = (destructor) Buffer_dealloc,
+    .tp_methods = Buffer_methods,
+};
+
+static inline char
+_check_buffer(PyObject *data) {
+    if (Py_TYPE(data) != &BufferType) {
+        PyErr_Format(
+            PyExc_TypeError, "data must be a Buffer object, got %s", Py_TYPE(data)->tp_name
+        );
+        return 2;
+    }
+    return 1;
+}
+
+static inline char
+_check_size(BufferObject *data, Py_ssize_t need) {
+    Py_ssize_t target = data->pos + need;
+    if (target <= data->size)
+        return 1;
+    // A buffer created from an empty bytes object has size 0; start the
+    // doubling from START_SIZE in that case so the loop below terminates.
+    if (data->size == 0)
+        data->size = START_SIZE;
+    while (target >= data->size)
+        data->size *= 2;
+    // Use a temporary so the original allocation is not leaked (and remains
+    // valid for the deallocator) if realloc fails.
+    char *new_buf = PyMem_Realloc(data->buf, data->size);
+    if (!new_buf) {
+        PyErr_NoMemory();
+        return 2;
+    }
+    data->buf = new_buf;
+    return 1;
+}
+
+static inline char
+_check_read(BufferObject *data, Py_ssize_t need) {
+    if (data->pos + need > data->end) {
+        PyErr_SetString(PyExc_ValueError, "reading past the buffer end");
+        return 2;
+    }
+    return 1;
+}
+
+static char
+read_bool_internal(PyObject *data) {
+    if (_check_buffer(data) == 2)
+        return 2;
+
+    if (_check_read((BufferObject *)data, 1) == 2)
+        return 2;
+    char *buf = ((BufferObject *)data)->buf;
+    char res = buf[((BufferObject *)data)->pos];
+    ((BufferObject *)data)->pos += 1;
+    return res;
+}
+
+static PyObject*
+read_bool(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", NULL};
+    PyObject *data = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
+        return NULL;
+    char res = read_bool_internal(data);
+    if (res == 2)
+        return NULL;
+    PyObject *retval = res ? Py_True : Py_False;
+    Py_INCREF(retval);
+    return retval;
+}
+
+static char
+write_bool_internal(PyObject *data, char value) {
+    if (_check_buffer(data) == 2)
+        return 2;
+
+    if (_check_size((BufferObject *)data, 1) == 2)
+        return 2;
+    char *buf = ((BufferObject *)data)->buf;
+    buf[((BufferObject *)data)->pos] = value;
+    ((BufferObject *)data)->pos += 1;
+    ((BufferObject *)data)->end += 1;
+    return 1;
+}
+
+static PyObject*
+write_bool(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", "value", NULL};
+    PyObject *data = NULL;
+    PyObject *value = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &data, &value))
+        return NULL;
+    if (!PyBool_Check(value)) {
+        PyErr_SetString(PyExc_TypeError, "value must be a bool");
+        return NULL;
+    }
+    if (write_bool_internal(data, value == Py_True) == 2) {
+        return NULL;
+    }
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
+static PyObject*
+read_str_internal(PyObject *data) {
+    if (_check_buffer(data) == 2)
+        return NULL;
+
+    if (_check_read((BufferObject *)data, sizeof(Py_ssize_t)) == 2)
+        return NULL;
+    char *buf = ((BufferObject *)data)->buf;
+    // Read string length.
+    Py_ssize_t size = *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos);
+    ((BufferObject *)data)->pos += sizeof(Py_ssize_t);
+    if (_check_read((BufferObject *)data, size) == 2)
+        return NULL;
+    // Read string content.
+    PyObject *res = PyUnicode_FromStringAndSize(
+        buf + ((BufferObject *)data)->pos, (Py_ssize_t)size
+    );
+    if (!res)
+        return NULL;
+    ((BufferObject *)data)->pos += size;
+    return res;
+}
+
+static PyObject*
+read_str(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", NULL};
+    PyObject *data = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
+        return NULL;
+    return read_str_internal(data);
+}
+
+static char
+write_str_internal(PyObject *data, PyObject *value) {
+    if (_check_buffer(data) == 2)
+        return 2;
+
+    Py_ssize_t size;
+    const char *chunk = PyUnicode_AsUTF8AndSize(value, &size);
+    if (!chunk)
+        return 2;
+    Py_ssize_t need = size + sizeof(Py_ssize_t);
+    if (_check_size((BufferObject *)data, need) == 2)
+        return 2;
+
+    char *buf = ((BufferObject *)data)->buf;
+    // Write string length.
+    *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos) = size;
+    ((BufferObject *)data)->pos += sizeof(Py_ssize_t);
+    // Write string content.
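+    // Note the wire format is the native Py_ssize_t length followed by the raw
+    // UTF-8 bytes, so cache files are word-size and endianness specific and
+    // must be written and read on the same platform.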
+    memcpy(buf + ((BufferObject *)data)->pos, chunk, size);
+    ((BufferObject *)data)->pos += size;
+    ((BufferObject *)data)->end += need;
+    return 1;
+}
+
+static PyObject*
+write_str(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", "value", NULL};
+    PyObject *data = NULL;
+    PyObject *value = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &data, &value))
+        return NULL;
+    if (!PyUnicode_Check(value)) {
+        PyErr_SetString(PyExc_TypeError, "value must be a str");
+        return NULL;
+    }
+    if (write_str_internal(data, value) == 2) {
+        return NULL;
+    }
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
+static double
+read_float_internal(PyObject *data) {
+    if (_check_buffer(data) == 2)
+        return CPY_FLOAT_ERROR;
+
+    if (_check_read((BufferObject *)data, sizeof(double)) == 2)
+        return CPY_FLOAT_ERROR;
+    char *buf = ((BufferObject *)data)->buf;
+    double res = *(double *)(buf + ((BufferObject *)data)->pos);
+    ((BufferObject *)data)->pos += sizeof(double);
+    return res;
+}
+
+static PyObject*
+read_float(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", NULL};
+    PyObject *data = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
+        return NULL;
+    double retval = read_float_internal(data);
+    if (retval == CPY_FLOAT_ERROR && PyErr_Occurred()) {
+        return NULL;
+    }
+    return PyFloat_FromDouble(retval);
+}
+
+static char
+write_float_internal(PyObject *data, double value) {
+    if (_check_buffer(data) == 2)
+        return 2;
+
+    if (_check_size((BufferObject *)data, sizeof(double)) == 2)
+        return 2;
+    char *buf = ((BufferObject *)data)->buf;
+    *(double *)(buf + ((BufferObject *)data)->pos) = value;
+    ((BufferObject *)data)->pos += sizeof(double);
+    ((BufferObject *)data)->end += sizeof(double);
+    return 1;
+}
+
+static PyObject*
+write_float(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", "value", NULL};
+    PyObject *data = NULL;
+    PyObject *value = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &data, &value))
+        return NULL;
+    if (!PyFloat_Check(value)) {
+        PyErr_SetString(PyExc_TypeError, "value must be a float");
+        return NULL;
+    }
+    if (write_float_internal(data, PyFloat_AsDouble(value)) == 2) {
+        return NULL;
+    }
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
+static CPyTagged
+read_int_internal(PyObject *data) {
+    if (_check_buffer(data) == 2)
+        return CPY_INT_TAG;
+
+    if (_check_read((BufferObject *)data, sizeof(CPyTagged)) == 2)
+        return CPY_INT_TAG;
+    char *buf = ((BufferObject *)data)->buf;
+
+    CPyTagged ret = *(CPyTagged *)(buf + ((BufferObject *)data)->pos);
+    ((BufferObject *)data)->pos += sizeof(CPyTagged);
+    if ((ret & CPY_INT_TAG) == 0)
+        return ret;
+    // Ints that don't fit in a tagged machine word are stored as decimal strings.
+    PyObject *str_ret = read_str_internal(data);
+    if (str_ret == NULL)
+        return CPY_INT_TAG;
+    PyObject* ret_long = PyLong_FromUnicodeObject(str_ret, 10);
+    Py_DECREF(str_ret);
+    if (ret_long == NULL)
+        return CPY_INT_TAG;
+    return ((CPyTagged)ret_long) | CPY_INT_TAG;
+}
+
+static PyObject*
+read_int(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", NULL};
+    PyObject *data = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
+        return NULL;
+    CPyTagged retval = read_int_internal(data);
+    if (retval == CPY_INT_TAG) {
+        return NULL;
+    }
+    return CPyTagged_StealAsObject(retval);
+}
+
+static char
+write_int_internal(PyObject *data, CPyTagged value) {
+    if (_check_buffer(data) == 2)
+        return 2;
+
+    if (_check_size((BufferObject *)data, sizeof(CPyTagged)) == 2)
+        return 2;
+    char *buf = ((BufferObject *)data)->buf;
+    if ((value & CPY_INT_TAG) == 0) {
+        *(CPyTagged *)(buf + ((BufferObject *)data)->pos) = value;
+    } else {
+        *(CPyTagged *)(buf + ((BufferObject *)data)->pos) = CPY_INT_TAG;
+    }
+    ((BufferObject *)data)->pos += sizeof(CPyTagged);
+    ((BufferObject *)data)->end += sizeof(CPyTagged);
+    if ((value & CPY_INT_TAG) != 0) {
+        PyObject *str_value = PyObject_Str(CPyTagged_LongAsObject(value));
+        if (str_value == NULL)
+            return 2;
+        char res = write_str_internal(data, str_value);
+        Py_DECREF(str_value);
+        if (res == 2)
+            return 2;
+    }
+    return 1;
+}
+
+static PyObject*
+write_int(PyObject *self, PyObject *args, PyObject *kwds) {
+    static char *kwlist[] = {"data", "value", NULL};
+    PyObject *data = NULL;
+    PyObject *value = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &data, &value))
+        return NULL;
+    if (!PyLong_Check(value)) {
+        PyErr_SetString(PyExc_TypeError, "value must be an int");
+        return NULL;
+    }
+    CPyTagged tagged_value = CPyTagged_BorrowFromObject(value);
+    if (write_int_internal(data, tagged_value) == 2) {
+        return NULL;
+    }
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
+static PyMethodDef native_internal_module_methods[] = {
+    // TODO: switch public wrappers to METH_FASTCALL.
+    {"write_bool", (PyCFunction)write_bool, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write a bool")},
+    {"read_bool", (PyCFunction)read_bool, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read a bool")},
+    {"write_str", (PyCFunction)write_str, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write a string")},
+    {"read_str", (PyCFunction)read_str, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read a string")},
+    {"write_float", (PyCFunction)write_float, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write a float")},
+    {"read_float", (PyCFunction)read_float, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read a float")},
+    {"write_int", (PyCFunction)write_int, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write an int")},
+    {"read_int", (PyCFunction)read_int, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read an int")},
+    {NULL, NULL, 0, NULL}
+};
+
+static int
+NativeInternal_ABI_Version(void) {
+    return NATIVE_INTERNAL_ABI_VERSION;
+}
+
+static int
+native_internal_module_exec(PyObject *m)
+{
+    if (PyType_Ready(&BufferType) < 0) {
+        return -1;
+    }
+    if (PyModule_AddObjectRef(m, "Buffer", (PyObject *) &BufferType) < 0) {
+        return -1;
+    }
+
+    // Export mypy internal C API, be careful with the order!
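+    // The position of each function pointer in this table must match the index
+    // used by the corresponding consumer-side macro in native_internal.h
+    // (NativeInternal_API[0..11]).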
+    static void *NativeInternal_API[12] = {
+        (void *)Buffer_internal,
+        (void *)Buffer_internal_empty,
+        (void *)Buffer_getvalue_internal,
+        (void *)write_bool_internal,
+        (void *)read_bool_internal,
+        (void *)write_str_internal,
+        (void *)read_str_internal,
+        (void *)write_float_internal,
+        (void *)read_float_internal,
+        (void *)write_int_internal,
+        (void *)read_int_internal,
+        (void *)NativeInternal_ABI_Version,
+    };
+    PyObject *c_api_object = PyCapsule_New((void *)NativeInternal_API, "native_internal._C_API", NULL);
+    if (PyModule_Add(m, "_C_API", c_api_object) < 0) {
+        return -1;
+    }
+    return 0;
+}
+
+static PyModuleDef_Slot native_internal_module_slots[] = {
+    {Py_mod_exec, native_internal_module_exec},
+#ifdef Py_MOD_GIL_NOT_USED
+    {Py_mod_gil, Py_MOD_GIL_NOT_USED},
+#endif
+    {0, NULL}
+};
+
+static PyModuleDef native_internal_module = {
+    .m_base = PyModuleDef_HEAD_INIT,
+    .m_name = "native_internal",
+    .m_doc = "Mypy cache serialization utils",
+    .m_size = 0,
+    .m_methods = native_internal_module_methods,
+    .m_slots = native_internal_module_slots,
+};
+
+PyMODINIT_FUNC
+PyInit_native_internal(void)
+{
+    return PyModuleDef_Init(&native_internal_module);
+}
diff --git a/mypyc/lib-rt/native_internal.h b/mypyc/lib-rt/native_internal.h
new file mode 100644
index 000000000000..3bd3dd1bbb33
--- /dev/null
+++ b/mypyc/lib-rt/native_internal.h
@@ -0,0 +1,52 @@
+#ifndef NATIVE_INTERNAL_H
+#define NATIVE_INTERNAL_H
+
+#define NATIVE_INTERNAL_ABI_VERSION 0
+
+#ifdef NATIVE_INTERNAL_MODULE
+
+static PyObject *Buffer_internal(PyObject *source);
+static PyObject *Buffer_internal_empty(void);
+static PyObject *Buffer_getvalue_internal(PyObject *self);
+static char write_bool_internal(PyObject *data, char value);
+static char read_bool_internal(PyObject *data);
+static char write_str_internal(PyObject *data, PyObject *value);
+static PyObject *read_str_internal(PyObject *data);
+static char write_float_internal(PyObject *data, double value);
+static double read_float_internal(PyObject *data);
+static char write_int_internal(PyObject *data, CPyTagged value);
+static CPyTagged read_int_internal(PyObject *data);
+static int NativeInternal_ABI_Version(void);
+
+#else
+
+static void **NativeInternal_API;
+
+#define Buffer_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[0])
+#define Buffer_internal_empty (*(PyObject* (*)(void)) NativeInternal_API[1])
+#define Buffer_getvalue_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[2])
+#define write_bool_internal (*(char (*)(PyObject *source, char value)) NativeInternal_API[3])
+#define read_bool_internal (*(char (*)(PyObject *source)) NativeInternal_API[4])
+#define write_str_internal (*(char (*)(PyObject *source, PyObject *value)) NativeInternal_API[5])
+#define read_str_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[6])
+#define write_float_internal (*(char (*)(PyObject *source, double value)) NativeInternal_API[7])
+#define read_float_internal (*(double (*)(PyObject *source)) NativeInternal_API[8])
+#define write_int_internal (*(char (*)(PyObject *source, CPyTagged value)) NativeInternal_API[9])
+#define read_int_internal (*(CPyTagged (*)(PyObject *source)) NativeInternal_API[10])
+#define NativeInternal_ABI_Version (*(int (*)(void)) NativeInternal_API[11])
+
+static int
+import_native_internal(void)
+{
+    NativeInternal_API = (void **)PyCapsule_Import("native_internal._C_API", 0);
+    if (NativeInternal_API == NULL)
+        return -1;
+    if (NativeInternal_ABI_Version() != NATIVE_INTERNAL_ABI_VERSION) {
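+        // The capsule exposes raw function pointers, so a consumer built against
+        // a different table layout would crash at runtime; fail loudly instead.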
+        PyErr_SetString(PyExc_ValueError, "ABI version conflict for native_internal");
+        return -1;
+    }
+    return 0;
+}
+
+#endif
+#endif  // NATIVE_INTERNAL_H
diff --git a/mypyc/lib-rt/setup.py b/mypyc/lib-rt/setup.py
index 1faacc8fc136..5b7a2919c0fd 100644
--- a/mypyc/lib-rt/setup.py
+++ b/mypyc/lib-rt/setup.py
@@ -12,60 +12,74 @@
 from distutils.core import Extension, setup
 from typing import Any

-kwargs: dict[str, Any]
-if sys.platform == "darwin":
-    kwargs = {"language": "c++"}
-    compile_args = []
-else:
-    kwargs = {}
-    compile_args = ["--std=c++11"]
+C_APIS_TO_TEST = [
+    "init.c",
+    "int_ops.c",
+    "float_ops.c",
+    "list_ops.c",
+    "exc_ops.c",
+    "generic_ops.c",
+    "pythonsupport.c",
+]

-class build_ext_custom(build_ext):  # noqa: N801
-    def get_library_names(self):
+class BuildExtGtest(build_ext):
+    def get_library_names(self) -> list[str]:
         return ["gtest"]

-    def run(self):
+    def run(self) -> None:
+        # Build Google Test, the C++ framework we use for testing C code.
+        # The source code for Google Test is copied to this repository.
         gtest_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "external", "googletest")
         )
-        os.makedirs(self.build_temp, exist_ok=True)
-
-        # Build Google Test, the C++ framework we use for testing C code.
-        # The source code for Google Test is copied to this repository.
+        # The build directory may not exist yet when we run make below.
+        os.makedirs(self.build_temp, exist_ok=True)
         subprocess.check_call(
             ["make", "-f", os.path.join(gtest_dir, "make", "Makefile"), f"GTEST_DIR={gtest_dir}"],
             cwd=self.build_temp,
         )
-
         self.library_dirs = [self.build_temp]
-
         return build_ext.run(self)

-setup(
-    name="test_capi",
-    version="0.1",
-    ext_modules=[
-        Extension(
-            "test_capi",
-            [
-                "test_capi.cc",
-                "init.c",
-                "int_ops.c",
-                "float_ops.c",
-                "list_ops.c",
-                "exc_ops.c",
-                "generic_ops.c",
-                "pythonsupport.c",
-            ],
-            depends=["CPy.h", "mypyc_util.h", "pythonsupport.h"],
-            extra_compile_args=["-Wno-unused-function", "-Wno-sign-compare"] + compile_args,
-            libraries=["gtest"],
-            include_dirs=["../external/googletest", "../external/googletest/include"],
-            **kwargs,
-        )
-    ],
-    cmdclass={"build_ext": build_ext_custom},
-)
+if "--run-capi-tests" in sys.argv:
+    # Use remove() rather than pop() so the flag does not have to be the last
+    # command line argument.
+    sys.argv.remove("--run-capi-tests")
+
+    kwargs: dict[str, Any]
+    if sys.platform == "darwin":
+        kwargs = {"language": "c++"}
+        compile_args = []
+    else:
+        kwargs = {}
+        compile_args = ["--std=c++11"]
+
+    setup(
+        name="test_capi",
+        version="0.1",
+        ext_modules=[
+            Extension(
+                "test_capi",
+                ["test_capi.cc"] + C_APIS_TO_TEST,
+                depends=["CPy.h", "mypyc_util.h", "pythonsupport.h"],
+                extra_compile_args=["-Wno-unused-function", "-Wno-sign-compare"] + compile_args,
+                libraries=["gtest"],
+                include_dirs=["../external/googletest", "../external/googletest/include"],
+                **kwargs,
+            )
+        ],
+        cmdclass={"build_ext": BuildExtGtest},
+    )
+else:
+    # TODO: we need a way to share our preferred C flags and get_extension() logic with
+    # mypyc/build.py without code duplication.
+    setup(
+        name="mypy-native",
+        version="0.0.1",
+        ext_modules=[
+            Extension(
+                "native_internal",
+                ["native_internal.c", "init.c", "int_ops.c", "exc_ops.c", "pythonsupport.c"],
+                include_dirs=["."],
+            )
+        ],
+    )
diff --git a/mypyc/options.py b/mypyc/options.py
index 50c76d3c0656..c009d3c6a7a4 100644
--- a/mypyc/options.py
+++ b/mypyc/options.py
@@ -17,6 +17,7 @@ def __init__(
         strict_dunder_typing: bool = False,
         group_name: str | None = None,
         log_trace: bool = False,
+        depends_on_native_internal: bool = False,
     ) -> None:
         self.strip_asserts = strip_asserts
         self.multi_file = multi_file
@@ -50,3 +51,7 @@ def __init__(
         # mypyc_trace.txt when compiled module is executed. This is useful for
         # performance analysis.
         self.log_trace = log_trace
+        # If enabled, add capsule imports of the native_internal C API. This should
+        # be used only for mypy itself; third-party code compiled with mypyc should
+        # not use native_internal.
+        self.depends_on_native_internal = depends_on_native_internal
diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py
index e3d59f53ed76..bf42584aef20 100644
--- a/mypyc/primitives/misc_ops.py
+++ b/mypyc/primitives/misc_ops.py
@@ -4,14 +4,18 @@
 from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_NEVER
 from mypyc.ir.rtypes import (
+    KNOWN_NATIVE_TYPES,
     bit_rprimitive,
     bool_rprimitive,
+    bytes_rprimitive,
     c_int_rprimitive,
     c_pointer_rprimitive,
     c_pyssize_t_rprimitive,
     cstring_rprimitive,
     dict_rprimitive,
+    float_rprimitive,
     int_rprimitive,
+    none_rprimitive,
     object_pointer_rprimitive,
     object_rprimitive,
     pointer_rprimitive,
@@ -24,6 +28,7 @@
     custom_primitive_op,
     function_op,
     load_address_op,
+    method_op,
 )

 # Get the 'bool' type object.
@@ -326,3 +331,95 @@
     return_type=void_rtype,
     error_kind=ERR_NEVER,
 )
+
+buffer_rprimitive = KNOWN_NATIVE_TYPES["native_internal.Buffer"]
+
+# Buffer(source)
+function_op(
+    name="native_internal.Buffer",
+    arg_types=[bytes_rprimitive],
+    return_type=buffer_rprimitive,
+    c_function_name="Buffer_internal",
+    error_kind=ERR_MAGIC,
+)
+
+# Buffer()
+function_op(
+    name="native_internal.Buffer",
+    arg_types=[],
+    return_type=buffer_rprimitive,
+    c_function_name="Buffer_internal_empty",
+    error_kind=ERR_MAGIC,
+)
+
+method_op(
+    name="getvalue",
+    arg_types=[buffer_rprimitive],
+    return_type=bytes_rprimitive,
+    c_function_name="Buffer_getvalue_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.write_bool",
+    arg_types=[object_rprimitive, bool_rprimitive],
+    return_type=none_rprimitive,
+    c_function_name="write_bool_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.read_bool",
+    arg_types=[object_rprimitive],
+    return_type=bool_rprimitive,
+    c_function_name="read_bool_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.write_str",
+    arg_types=[object_rprimitive, str_rprimitive],
+    return_type=none_rprimitive,
+    c_function_name="write_str_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.read_str",
+    arg_types=[object_rprimitive],
+    return_type=str_rprimitive,
+    c_function_name="read_str_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.write_float",
+    arg_types=[object_rprimitive, float_rprimitive],
+    return_type=none_rprimitive,
+    c_function_name="write_float_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.read_float",
+    arg_types=[object_rprimitive],
+    return_type=float_rprimitive,
+    c_function_name="read_float_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.write_int",
+    arg_types=[object_rprimitive, int_rprimitive],
+    return_type=none_rprimitive,
+    c_function_name="write_int_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="native_internal.read_int",
+    arg_types=[object_rprimitive],
+    return_type=int_rprimitive,
+    c_function_name="read_int_internal",
+    error_kind=ERR_MAGIC,
+)
diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test
index c7bf5de852a8..da7e074886a2 100644
--- a/mypyc/test-data/irbuild-classes.test
+++ b/mypyc/test-data/irbuild-classes.test
@@ -1409,6 +1409,55 @@ class TestOverload:
     def __mypyc_generator_helper__(self, x: Any) -> Any:
         return x
+[case testNativeBufferFastPath]
+from native_internal import (
+    Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, write_int, read_int
+)
+
+def foo() -> None:
+    b = Buffer()
+    write_str(b, "foo")
+    write_bool(b, True)
+    write_float(b, 0.1)
+    write_int(b, 1)
+
+    b = Buffer(b.getvalue())
+    x = read_str(b)
+    y = read_bool(b)
+    z = read_float(b)
+    t = read_int(b)
+[out]
+def foo():
+    r0, b :: native_internal.Buffer
+    r1 :: str
+    r2, r3, r4, r5 :: None
+    r6 :: bytes
+    r7 :: native_internal.Buffer
+    r8, x :: str
+    r9, y :: bool
+    r10, z :: float
+    r11, t :: int
+L0:
+    r0 = Buffer_internal_empty()
+    b = r0
+    r1 = 'foo'
+    r2 = write_str_internal(b, r1)
+    r3 = write_bool_internal(b, 1)
+    r4 = write_float_internal(b, 0.1)
+    r5 = write_int_internal(b, 2)
+    r6 = Buffer_getvalue_internal(b)
+    r7 = Buffer_internal(r6)
+    b = r7
+    r8 = read_str_internal(b)
+    x = r8
+    r9 = read_bool_internal(b)
+    y = r9
+    r10 = read_float_internal(b)
+    z = r10
+    r11 = read_int_internal(b)
+    t = r11
+    return 1
+
 [case testEnumFastPath]
 from enum import Enum
diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test
index 1481f3e06871..0f07636c4572 100644
--- a/mypyc/test-data/run-classes.test
+++ b/mypyc/test-data/run-classes.test
@@ -2710,6 +2710,74 @@ from native import Player
 [out]
 Player.MIN = 
+[case testBufferRoundTrip_native_libs]
+from native_internal import (
+    Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, write_int, read_int
+)
+
+def test_buffer_basic() -> None:
+    b = Buffer(b"foo")
+    assert b.getvalue() == b"foo"
+
+def test_buffer_roundtrip() -> None:
+    b = Buffer()
+    write_str(b, "foo")
+    write_bool(b, True)
+    write_str(b, "bar" * 1000)
+    write_bool(b, False)
+    write_float(b, 0.1)
+    write_int(b, 0)
+    write_int(b, 1)
+    write_int(b, 2)
+    write_int(b, 2 ** 85)
+
+    b = Buffer(b.getvalue())
+    assert read_str(b) == "foo"
+    assert read_bool(b) is True
+    assert read_str(b) == "bar" * 1000
+    assert read_bool(b) is False
+    assert read_float(b) == 0.1
+    assert read_int(b) == 0
+    assert read_int(b) == 1
+    assert read_int(b) == 2
+    assert read_int(b) == 2 ** 85
+
+[file driver.py]
+from native import *
+
+test_buffer_basic()
+test_buffer_roundtrip()
+
+def test_buffer_basic_interpreted() -> None:
+    b = Buffer(b"foo")
+    assert b.getvalue() == b"foo"
+
+def test_buffer_roundtrip_interpreted() -> None:
+    b = Buffer()
+    write_str(b, "foo")
+    write_bool(b, True)
+    write_str(b, "bar" * 1000)
+    write_bool(b, False)
+    write_float(b, 0.1)
+    write_int(b, 0)
+    write_int(b, 1)
+    write_int(b, 2)
+    write_int(b, 2 ** 85)
+
+    b = Buffer(b.getvalue())
+    assert read_str(b) == "foo"
+    assert read_bool(b) is True
+    assert read_str(b) == "bar" * 1000
+    assert read_bool(b) is False
+    assert read_float(b) == 0.1
+    assert read_int(b) == 0
+    assert read_int(b) == 1
+    assert read_int(b) == 2
+    assert read_int(b) == 2 ** 85
+
+test_buffer_basic_interpreted()
+test_buffer_roundtrip_interpreted()
+
 [case testEnumMethodCalls]
 from enum import Enum
 from typing import overload, Optional, Union
diff --git a/mypyc/test/test_external.py b/mypyc/test/test_external.py
index 010c74dee42e..a416cf2ee130 100644
--- a/mypyc/test/test_external.py
+++ b/mypyc/test/test_external.py
@@ -34,6 +34,7 @@ def test_c_unit_test(self) -> None:
                 "build_ext",
                 f"--build-lib={tmpdir}",
                 f"--build-temp={tmpdir}",
+                "--run-capi-tests",
             ],
             env=env,
             cwd=os.path.join(base_dir, "mypyc", "lib-rt"),
diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py
index 5078426b977f..172a1016dd91 100644
--- a/mypyc/test/test_run.py
+++ b/mypyc/test/test_run.py
@@ -86,7 +86,7 @@
 setup(name='test_run_output',
       ext_modules=mypycify({}, separate={}, skip_cgen_input={!r}, strip_asserts=False,
-                           multi_file={}, opt_level='{}'),
+                           multi_file={}, opt_level='{}', install_native_libs={}),
 )
 """
@@ -239,11 +239,13 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) ->
         groups = construct_groups(sources, separate, len(module_names) > 1, None)

+        native_libs = "_native_libs" in testcase.name
         try:
             compiler_options = CompilerOptions(
                 multi_file=self.multi_file,
                 separate=self.separate,
                 strict_dunder_typing=self.strict_dunder_typing,
+                depends_on_native_internal=native_libs,
             )
             result = emitmodule.parse_and_typecheck(
                 sources=sources,
@@ -270,14 +272,13 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) ->
                 check_serialization_roundtrip(ir)

         opt_level = int(os.environ.get("MYPYC_OPT_LEVEL", 0))
-        debug_level = int(os.environ.get("MYPYC_DEBUG_LEVEL", 0))

         setup_file = os.path.abspath(os.path.join(WORKDIR, "setup.py"))
         # We pass the C file information to the build script via setup.py unfortunately
         with open(setup_file, "w", encoding="utf-8") as f:
             f.write(
                 setup_format.format(
-                    module_paths, separate, cfiles, self.multi_file, opt_level, debug_level
+                    module_paths, separate, cfiles, self.multi_file, opt_level, native_libs
                 )
             )
diff --git a/setup.py b/setup.py
index e085b0be3846..798ff4f6c710 100644
--- a/setup.py
+++ b/setup.py
@@ -154,6 +154,10 @@ def run(self) -> None:
             # our Appveyor builds run out of memory sometimes.
             multi_file=sys.platform == "win32" or force_multifile,
             log_trace=log_trace,
+            # Mypy itself is allowed to use the native_internal extension.
+            depends_on_native_internal=True,
+            # TODO: temporary, remove this after we publish mypy-native on PyPI.
+            install_native_libs=True,
         )
     else:
diff --git a/test-data/unit/lib-stub/native_internal.pyi b/test-data/unit/lib-stub/native_internal.pyi
new file mode 100644
index 000000000000..bc1f570a8e9c
--- /dev/null
+++ b/test-data/unit/lib-stub/native_internal.pyi
@@ -0,0 +1,12 @@
+class Buffer:
+    def __init__(self, source: bytes = ...) -> None: ...
+    def getvalue(self) -> bytes: ...
+
+def write_bool(data: Buffer, value: bool) -> None: ...
+def read_bool(data: Buffer) -> bool: ...
+def write_str(data: Buffer, value: str) -> None: ...
+def read_str(data: Buffer) -> str: ...
+def write_float(data: Buffer, value: float) -> None: ...
+def read_float(data: Buffer) -> float: ...
+def write_int(data: Buffer, value: int) -> None: ...
+def read_int(data: Buffer) -> int: ...
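A note on the integer encoding used by write_int_internal/read_int_internal: the
patch stores mypyc's tagged representation directly. The pure-Python sketch below
mirrors that wire format for illustration only; encode_int/decode_int are
hypothetical names, and it assumes a 64-bit platform where small ints are stored
shifted left by one bit and anything larger is spilled as a length-prefixed
decimal string (matching the string format above).

import struct

CPY_INT_TAG = 1  # low bit set marks a "boxed" arbitrary-precision int
WORD = struct.Struct("@n")  # native Py_ssize_t, as in the C code

def encode_int(value: int) -> bytes:
    if -(2 ** 62) <= value < 2 ** 62:
        # Small ints fit in a tagged word: shifted left one bit, tag bit clear
        # (this is mypyc's CPyTagged representation, written verbatim).
        return WORD.pack(value << 1)
    # Big ints: a word with only the tag bit set, then the decimal digits in
    # the same length-prefixed format used for strings.
    digits = str(value).encode("utf-8")
    return WORD.pack(CPY_INT_TAG) + WORD.pack(len(digits)) + digits

def decode_int(data: bytes) -> int:
    (word,) = WORD.unpack_from(data, 0)
    if word & CPY_INT_TAG == 0:
        return word >> 1
    (size,) = WORD.unpack_from(data, WORD.size)
    start = 2 * WORD.size
    return int(data[start : start + size].decode("utf-8"))

for n in (0, 1, -1, 2 ** 62 - 1, 2 ** 85, -(2 ** 85)):
    assert decode_int(encode_int(n)) == n

As with the string format, the writer's word size and byte order leak into the
encoding, so the fixed-format cache is a per-machine artifact, which is
sufficient for its purpose.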