diff --git a/mypy/cache.py b/mypy/cache.py index 49f568c1f3c1..a16a36900c7a 100644 --- a/mypy/cache.py +++ b/mypy/cache.py @@ -10,10 +10,12 @@ read_float as read_float, read_int as read_int, read_str as read_str, + read_tag as read_tag, write_bool as write_bool, write_float as write_float, write_int as write_int, write_str as write_str, + write_tag as write_tag, ) except ImportError: # TODO: temporary, remove this after we publish mypy-native on PyPI. @@ -32,6 +34,12 @@ def read_int(data: Buffer) -> int: def write_int(data: Buffer, value: int) -> None: raise NotImplementedError + def read_tag(data: Buffer) -> int: + raise NotImplementedError + + def write_tag(data: Buffer, value: int) -> None: + raise NotImplementedError + def read_str(data: Buffer) -> str: raise NotImplementedError @@ -59,37 +67,37 @@ def write_float(data: Buffer, value: float) -> None: LITERAL_NONE: Final = 6 -def read_literal(data: Buffer, marker: int) -> int | str | bool | float: - if marker == LITERAL_INT: +def read_literal(data: Buffer, tag: int) -> int | str | bool | float: + if tag == LITERAL_INT: return read_int(data) - elif marker == LITERAL_STR: + elif tag == LITERAL_STR: return read_str(data) - elif marker == LITERAL_BOOL: + elif tag == LITERAL_BOOL: return read_bool(data) - elif marker == LITERAL_FLOAT: + elif tag == LITERAL_FLOAT: return read_float(data) - assert False, f"Unknown literal marker {marker}" + assert False, f"Unknown literal tag {tag}" def write_literal(data: Buffer, value: int | str | bool | float | complex | None) -> None: if isinstance(value, bool): - write_int(data, LITERAL_BOOL) + write_tag(data, LITERAL_BOOL) write_bool(data, value) elif isinstance(value, int): - write_int(data, LITERAL_INT) + write_tag(data, LITERAL_INT) write_int(data, value) elif isinstance(value, str): - write_int(data, LITERAL_STR) + write_tag(data, LITERAL_STR) write_str(data, value) elif isinstance(value, float): - write_int(data, LITERAL_FLOAT) + write_tag(data, LITERAL_FLOAT) write_float(data, value) elif isinstance(value, complex): - write_int(data, LITERAL_COMPLEX) + write_tag(data, LITERAL_COMPLEX) write_float(data, value.real) write_float(data, value.imag) else: - write_int(data, LITERAL_NONE) + write_tag(data, LITERAL_NONE) def read_int_opt(data: Buffer) -> int | None: diff --git a/mypy/nodes.py b/mypy/nodes.py index b9c08f02f316..45e2b60c3e78 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -28,6 +28,7 @@ read_str_list, read_str_opt, read_str_opt_list, + read_tag, write_bool, write_int, write_int_list, @@ -37,6 +38,7 @@ write_str_list, write_str_opt, write_str_opt_list, + write_tag, ) from mypy.options import Options from mypy.util import is_sunder, is_typeshed_file, short_type @@ -417,7 +419,7 @@ def deserialize(cls, data: JsonDict) -> MypyFile: return tree def write(self, data: Buffer) -> None: - write_int(data, MYPY_FILE) + write_tag(data, MYPY_FILE) write_str(data, self._fullname) self.names.write(data, self._fullname) write_bool(data, self.is_stub) @@ -427,7 +429,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> MypyFile: - assert read_int(data) == MYPY_FILE + assert read_tag(data) == MYPY_FILE tree = MypyFile([], []) tree._fullname = read_str(data) tree.names = SymbolTable.read(data) @@ -711,7 +713,7 @@ def deserialize(cls, data: JsonDict) -> OverloadedFuncDef: return res def write(self, data: Buffer) -> None: - write_int(data, OVERLOADED_FUNC_DEF) + write_tag(data, OVERLOADED_FUNC_DEF) write_int(data, len(self.items)) for item in self.items: item.write(data) @@ -1022,7 +1024,7 @@ def deserialize(cls, data: JsonDict) -> FuncDef: return ret def write(self, data: Buffer) -> None: - write_int(data, FUNC_DEF) + write_tag(data, FUNC_DEF) write_str(data, self._name) mypy.types.write_type_opt(data, self.type) write_str(data, self._fullname) @@ -1134,16 +1136,16 @@ def deserialize(cls, data: JsonDict) -> Decorator: return dec def write(self, data: Buffer) -> None: - write_int(data, DECORATOR) + write_tag(data, DECORATOR) self.func.write(data) self.var.write(data) write_bool(data, self.is_overload) @classmethod def read(cls, data: Buffer) -> Decorator: - assert read_int(data) == FUNC_DEF + assert read_tag(data) == FUNC_DEF func = FuncDef.read(data) - assert read_int(data) == VAR + assert read_tag(data) == VAR var = Var.read(data) dec = Decorator(func, [], var) dec.is_overload = read_bool(data) @@ -1326,7 +1328,7 @@ def deserialize(cls, data: JsonDict) -> Var: return v def write(self, data: Buffer) -> None: - write_int(data, VAR) + write_tag(data, VAR) write_str(data, self._name) mypy.types.write_type_opt(data, self.type) mypy.types.write_type_opt(data, self.setter_type) @@ -1341,13 +1343,13 @@ def read(cls, data: Buffer) -> Var: v = Var(name, typ) setter_type: mypy.types.CallableType | None = None if read_bool(data): - assert read_int(data) == mypy.types.CALLABLE_TYPE + assert read_tag(data) == mypy.types.CALLABLE_TYPE setter_type = mypy.types.CallableType.read(data) v.setter_type = setter_type v.is_ready = False # Override True default set in __init__ v._fullname = read_str(data) read_flags(data, v, VAR_FLAGS) - marker = read_int(data) + marker = read_tag(data) if marker == LITERAL_COMPLEX: v.final_value = complex(read_float(data), read_float(data)) elif marker != LITERAL_NONE: @@ -1465,7 +1467,7 @@ def deserialize(cls, data: JsonDict) -> ClassDef: return res def write(self, data: Buffer) -> None: - write_int(data, CLASS_DEF) + write_tag(data, CLASS_DEF) write_str(data, self.name) mypy.types.write_type_list(data, self.type_vars) write_str(data, self.fullname) @@ -2898,7 +2900,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarExpr: ) def write(self, data: Buffer) -> None: - write_int(data, TYPE_VAR_EXPR) + write_tag(data, TYPE_VAR_EXPR) write_str(data, self._name) write_str(data, self._fullname) mypy.types.write_type_list(data, self.values) @@ -2948,7 +2950,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecExpr: ) def write(self, data: Buffer) -> None: - write_int(data, PARAM_SPEC_EXPR) + write_tag(data, PARAM_SPEC_EXPR) write_str(data, self._name) write_str(data, self._fullname) self.upper_bound.write(data) @@ -3016,7 +3018,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr: ) def write(self, data: Buffer) -> None: - write_int(data, TYPE_VAR_TUPLE_EXPR) + write_tag(data, TYPE_VAR_TUPLE_EXPR) self.tuple_fallback.write(data) write_str(data, self._name) write_str(data, self._fullname) @@ -3026,7 +3028,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> TypeVarTupleExpr: - assert read_int(data) == mypy.types.INSTANCE + assert read_tag(data) == mypy.types.INSTANCE fallback = mypy.types.Instance.read(data) return TypeVarTupleExpr( read_str(data), @@ -3908,7 +3910,7 @@ def deserialize(cls, data: JsonDict) -> TypeInfo: return ti def write(self, data: Buffer) -> None: - write_int(data, TYPE_INFO) + write_tag(data, TYPE_INFO) self.names.write(data, self.fullname) self.defn.write(data) write_str(data, self.module_name) @@ -3944,7 +3946,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> TypeInfo: names = SymbolTable.read(data) - assert read_int(data) == CLASS_DEF + assert read_tag(data) == CLASS_DEF defn = ClassDef.read(data) module_name = read_str(data) ti = TypeInfo(names, defn, module_name) @@ -3954,10 +3956,9 @@ def read(cls, data: Buffer) -> TypeInfo: ti.abstract_attributes = list(zip(attrs, statuses)) ti.type_vars = read_str_list(data) ti.has_param_spec_type = read_bool(data) - num_bases = read_int(data) ti.bases = [] - for _ in range(num_bases): - assert read_int(data) == mypy.types.INSTANCE + for _ in range(read_int(data)): + assert read_tag(data) == mypy.types.INSTANCE ti.bases.append(mypy.types.Instance.read(data)) # NOTE: ti.mro will be set in the fixup phase based on these # names. The reason we need to store the mro instead of just @@ -3972,19 +3973,19 @@ def read(cls, data: Buffer) -> TypeInfo: ti._mro_refs = read_str_list(data) ti._promote = cast(list[mypy.types.ProperType], mypy.types.read_type_list(data)) if read_bool(data): - assert read_int(data) == mypy.types.INSTANCE + assert read_tag(data) == mypy.types.INSTANCE ti.alt_promote = mypy.types.Instance.read(data) if read_bool(data): - assert read_int(data) == mypy.types.INSTANCE + assert read_tag(data) == mypy.types.INSTANCE ti.declared_metaclass = mypy.types.Instance.read(data) if read_bool(data): - assert read_int(data) == mypy.types.INSTANCE + assert read_tag(data) == mypy.types.INSTANCE ti.metaclass_type = mypy.types.Instance.read(data) if read_bool(data): - assert read_int(data) == mypy.types.TUPLE_TYPE + assert read_tag(data) == mypy.types.TUPLE_TYPE ti.tuple_type = mypy.types.TupleType.read(data) if read_bool(data): - assert read_int(data) == mypy.types.TYPED_DICT_TYPE + assert read_tag(data) == mypy.types.TYPED_DICT_TYPE ti.typeddict_type = mypy.types.TypedDictType.read(data) read_flags(data, ti, TypeInfo.FLAGS) metadata = read_str(data) @@ -3994,7 +3995,7 @@ def read(cls, data: Buffer) -> TypeInfo: ti.slots = set(read_str_list(data)) ti.deletable_attributes = read_str_list(data) if read_bool(data): - assert read_int(data) == mypy.types.TYPE_VAR_TYPE + assert read_tag(data) == mypy.types.TYPE_VAR_TYPE ti.self_type = mypy.types.TypeVarType.read(data) if read_bool(data): ti.dataclass_transform_spec = DataclassTransformSpec.read(data) @@ -4270,7 +4271,7 @@ def deserialize(cls, data: JsonDict) -> TypeAlias: ) def write(self, data: Buffer) -> None: - write_int(data, TYPE_ALIAS) + write_tag(data, TYPE_ALIAS) write_str(data, self._fullname) self.target.write(data) mypy.types.write_type_list(data, self.alias_tvars) @@ -4890,33 +4891,33 @@ def local_definitions( def read_symbol(data: Buffer) -> mypy.nodes.SymbolNode: - marker = read_int(data) + tag = read_tag(data) # The branches here are ordered manually by type "popularity". - if marker == VAR: + if tag == VAR: return mypy.nodes.Var.read(data) - if marker == FUNC_DEF: + if tag == FUNC_DEF: return mypy.nodes.FuncDef.read(data) - if marker == DECORATOR: + if tag == DECORATOR: return mypy.nodes.Decorator.read(data) - if marker == TYPE_INFO: + if tag == TYPE_INFO: return mypy.nodes.TypeInfo.read(data) - if marker == OVERLOADED_FUNC_DEF: + if tag == OVERLOADED_FUNC_DEF: return mypy.nodes.OverloadedFuncDef.read(data) - if marker == TYPE_VAR_EXPR: + if tag == TYPE_VAR_EXPR: return mypy.nodes.TypeVarExpr.read(data) - if marker == TYPE_ALIAS: + if tag == TYPE_ALIAS: return mypy.nodes.TypeAlias.read(data) - if marker == PARAM_SPEC_EXPR: + if tag == PARAM_SPEC_EXPR: return mypy.nodes.ParamSpecExpr.read(data) - if marker == TYPE_VAR_TUPLE_EXPR: + if tag == TYPE_VAR_TUPLE_EXPR: return mypy.nodes.TypeVarTupleExpr.read(data) - assert False, f"Unknown symbol marker {marker}" + assert False, f"Unknown symbol tag {tag}" def read_overload_part(data: Buffer) -> OverloadPart: - marker = read_int(data) - if marker == DECORATOR: + tag = read_tag(data) + if tag == DECORATOR: return Decorator.read(data) - if marker == FUNC_DEF: + if tag == FUNC_DEF: return FuncDef.read(data) - assert False, f"Invalid marker for an OverloadPart {marker}" + assert False, f"Invalid tag for an OverloadPart {tag}" diff --git a/mypy/types.py b/mypy/types.py index b48e0ef4d985..43e6dafe298e 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -20,6 +20,7 @@ read_str_list, read_str_opt, read_str_opt_list, + read_tag, write_bool, write_int, write_int_list, @@ -28,6 +29,7 @@ write_str_list, write_str_opt, write_str_opt_list, + write_tag, ) from mypy.nodes import ARG_KINDS, ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, SymbolNode from mypy.options import Options @@ -456,7 +458,7 @@ def deserialize(cls, data: JsonDict) -> TypeAliasType: return alias def write(self, data: Buffer) -> None: - write_int(data, TYPE_ALIAS_TYPE) + write_tag(data, TYPE_ALIAS_TYPE) write_type_list(data, self.args) assert self.alias is not None write_str(data, self.alias.fullname) @@ -735,7 +737,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarType: ) def write(self, data: Buffer) -> None: - write_int(data, TYPE_VAR_TYPE) + write_tag(data, TYPE_VAR_TYPE) write_str(data, self.name) write_str(data, self.fullname) write_int(data, self.id.raw_id) @@ -887,7 +889,7 @@ def deserialize(cls, data: JsonDict) -> ParamSpecType: ) def write(self, data: Buffer) -> None: - write_int(data, PARAM_SPEC_TYPE) + write_tag(data, PARAM_SPEC_TYPE) self.prefix.write(data) write_str(data, self.name) write_str(data, self.fullname) @@ -899,7 +901,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> ParamSpecType: - assert read_int(data) == PARAMETERS + assert read_tag(data) == PARAMETERS prefix = Parameters.read(data) return ParamSpecType( read_str(data), @@ -967,7 +969,7 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleType: ) def write(self, data: Buffer) -> None: - write_int(data, TYPE_VAR_TUPLE_TYPE) + write_tag(data, TYPE_VAR_TUPLE_TYPE) self.tuple_fallback.write(data) write_str(data, self.name) write_str(data, self.fullname) @@ -979,7 +981,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> TypeVarTupleType: - assert read_int(data) == INSTANCE + assert read_tag(data) == INSTANCE fallback = Instance.read(data) return TypeVarTupleType( read_str(data), @@ -1123,7 +1125,7 @@ def deserialize(cls, data: JsonDict) -> UnboundType: ) def write(self, data: Buffer) -> None: - write_int(data, UNBOUND_TYPE) + write_tag(data, UNBOUND_TYPE) write_str(data, self.name) write_type_list(data, self.args) write_str_opt(data, self.original_str_expr) @@ -1233,7 +1235,7 @@ def serialize(self) -> JsonDict: return {".class": "UnpackType", "type": self.type.serialize()} def write(self, data: Buffer) -> None: - write_int(data, UNPACK_TYPE) + write_tag(data, UNPACK_TYPE) self.type.write(data) @classmethod @@ -1342,7 +1344,7 @@ def deserialize(cls, data: JsonDict) -> AnyType: ) def write(self, data: Buffer) -> None: - write_int(data, ANY_TYPE) + write_tag(data, ANY_TYPE) write_type_opt(data, self.source_any) write_int(data, self.type_of_any) write_str_opt(data, self.missing_import_name) @@ -1350,7 +1352,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> AnyType: if read_bool(data): - assert read_int(data) == ANY_TYPE + assert read_tag(data) == ANY_TYPE source_any = AnyType.read(data) else: source_any = None @@ -1403,7 +1405,7 @@ def deserialize(cls, data: JsonDict) -> UninhabitedType: return UninhabitedType() def write(self, data: Buffer) -> None: - write_int(data, UNINHABITED_TYPE) + write_tag(data, UNINHABITED_TYPE) @classmethod def read(cls, data: Buffer) -> UninhabitedType: @@ -1442,7 +1444,7 @@ def deserialize(cls, data: JsonDict) -> NoneType: return NoneType() def write(self, data: Buffer) -> None: - write_int(data, NONE_TYPE) + write_tag(data, NONE_TYPE) @classmethod def read(cls, data: Buffer) -> NoneType: @@ -1496,7 +1498,7 @@ def deserialize(cls, data: JsonDict) -> DeletedType: return DeletedType(data["source"]) def write(self, data: Buffer) -> None: - write_int(data, DELETED_TYPE) + write_tag(data, DELETED_TYPE) write_str_opt(data, self.source) @classmethod @@ -1704,7 +1706,7 @@ def deserialize(cls, data: JsonDict | str) -> Instance: return inst def write(self, data: Buffer) -> None: - write_int(data, INSTANCE) + write_tag(data, INSTANCE) write_str(data, self.type.fullname) write_type_list(data, self.args) write_type_opt(data, self.last_known_value) @@ -1720,7 +1722,7 @@ def read(cls, data: Buffer) -> Instance: inst = Instance(NOT_READY, read_type_list(data)) inst.type_ref = type_ref if read_bool(data): - assert read_int(data) == LITERAL_TYPE + assert read_tag(data) == LITERAL_TYPE inst.last_known_value = LiteralType.read(data) if read_bool(data): inst.extra_attrs = ExtraAttrs.read(data) @@ -2003,7 +2005,7 @@ def deserialize(cls, data: JsonDict) -> Parameters: ) def write(self, data: Buffer) -> None: - write_int(data, PARAMETERS) + write_tag(data, PARAMETERS) write_type_list(data, self.arg_types) write_int_list(data, [int(x.value) for x in self.arg_kinds]) write_str_opt_list(data, self.arg_names) @@ -2525,7 +2527,7 @@ def deserialize(cls, data: JsonDict) -> CallableType: ) def write(self, data: Buffer) -> None: - write_int(data, CALLABLE_TYPE) + write_tag(data, CALLABLE_TYPE) self.fallback.write(data) write_type_list(data, self.arg_types) write_int_list(data, [int(x.value) for x in self.arg_kinds]) @@ -2544,7 +2546,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> CallableType: - assert read_int(data) == INSTANCE + assert read_tag(data) == INSTANCE fallback = Instance.read(data) return CallableType( read_type_list(data), @@ -2640,15 +2642,14 @@ def deserialize(cls, data: JsonDict) -> Overloaded: return Overloaded([CallableType.deserialize(t) for t in data["items"]]) def write(self, data: Buffer) -> None: - write_int(data, OVERLOADED) + write_tag(data, OVERLOADED) write_type_list(data, self.items) @classmethod def read(cls, data: Buffer) -> Overloaded: items = [] - num_overloads = read_int(data) - for _ in range(num_overloads): - assert read_int(data) == CALLABLE_TYPE + for _ in range(read_int(data)): + assert read_tag(data) == CALLABLE_TYPE items.append(CallableType.read(data)) return Overloaded(items) @@ -2749,14 +2750,14 @@ def deserialize(cls, data: JsonDict) -> TupleType: ) def write(self, data: Buffer) -> None: - write_int(data, TUPLE_TYPE) + write_tag(data, TUPLE_TYPE) self.partial_fallback.write(data) write_type_list(data, self.items) write_bool(data, self.implicit) @classmethod def read(cls, data: Buffer) -> TupleType: - assert read_int(data) == INSTANCE + assert read_tag(data) == INSTANCE fallback = Instance.read(data) return TupleType(read_type_list(data), fallback, implicit=read_bool(data)) @@ -2931,7 +2932,7 @@ def deserialize(cls, data: JsonDict) -> TypedDictType: ) def write(self, data: Buffer) -> None: - write_int(data, TYPED_DICT_TYPE) + write_tag(data, TYPED_DICT_TYPE) self.fallback.write(data) write_type_map(data, self.items) write_str_list(data, sorted(self.required_keys)) @@ -2939,7 +2940,7 @@ def write(self, data: Buffer) -> None: @classmethod def read(cls, data: Buffer) -> TypedDictType: - assert read_int(data) == INSTANCE + assert read_tag(data) == INSTANCE fallback = Instance.read(data) return TypedDictType( read_type_map(data), set(read_str_list(data)), set(read_str_list(data)), fallback @@ -3194,16 +3195,16 @@ def deserialize(cls, data: JsonDict) -> LiteralType: return LiteralType(value=data["value"], fallback=Instance.deserialize(data["fallback"])) def write(self, data: Buffer) -> None: - write_int(data, LITERAL_TYPE) + write_tag(data, LITERAL_TYPE) self.fallback.write(data) write_literal(data, self.value) @classmethod def read(cls, data: Buffer) -> LiteralType: - assert read_int(data) == INSTANCE + assert read_tag(data) == INSTANCE fallback = Instance.read(data) - marker = read_int(data) - return LiteralType(read_literal(data, marker), fallback) + tag = read_tag(data) + return LiteralType(read_literal(data, tag), fallback) def is_singleton_type(self) -> bool: return self.is_enum_literal() or isinstance(self.value, bool) @@ -3307,7 +3308,7 @@ def deserialize(cls, data: JsonDict) -> UnionType: ) def write(self, data: Buffer) -> None: - write_int(data, UNION_TYPE) + write_tag(data, UNION_TYPE) write_type_list(data, self.items) write_bool(data, self.uses_pep604_syntax) @@ -3452,7 +3453,7 @@ def deserialize(cls, data: JsonDict) -> Type: return TypeType.make_normalized(deserialize_type(data["item"])) def write(self, data: Buffer) -> None: - write_int(data, TYPE_TYPE) + write_tag(data, TYPE_TYPE) self.item.write(data) @classmethod @@ -4141,67 +4142,67 @@ def type_vars_as_args(type_vars: Sequence[TypeVarLikeType]) -> tuple[Type, ...]: def read_type(data: Buffer) -> Type: - marker = read_int(data) + tag = read_tag(data) # The branches here are ordered manually by type "popularity". - if marker == INSTANCE: + if tag == INSTANCE: return Instance.read(data) - if marker == ANY_TYPE: + if tag == ANY_TYPE: return AnyType.read(data) - if marker == TYPE_VAR_TYPE: + if tag == TYPE_VAR_TYPE: return TypeVarType.read(data) - if marker == CALLABLE_TYPE: + if tag == CALLABLE_TYPE: return CallableType.read(data) - if marker == NONE_TYPE: + if tag == NONE_TYPE: return NoneType.read(data) - if marker == UNION_TYPE: + if tag == UNION_TYPE: return UnionType.read(data) - if marker == LITERAL_TYPE: + if tag == LITERAL_TYPE: return LiteralType.read(data) - if marker == TYPE_ALIAS_TYPE: + if tag == TYPE_ALIAS_TYPE: return TypeAliasType.read(data) - if marker == TUPLE_TYPE: + if tag == TUPLE_TYPE: return TupleType.read(data) - if marker == TYPED_DICT_TYPE: + if tag == TYPED_DICT_TYPE: return TypedDictType.read(data) - if marker == TYPE_TYPE: + if tag == TYPE_TYPE: return TypeType.read(data) - if marker == OVERLOADED: + if tag == OVERLOADED: return Overloaded.read(data) - if marker == PARAM_SPEC_TYPE: + if tag == PARAM_SPEC_TYPE: return ParamSpecType.read(data) - if marker == TYPE_VAR_TUPLE_TYPE: + if tag == TYPE_VAR_TUPLE_TYPE: return TypeVarTupleType.read(data) - if marker == UNPACK_TYPE: + if tag == UNPACK_TYPE: return UnpackType.read(data) - if marker == PARAMETERS: + if tag == PARAMETERS: return Parameters.read(data) - if marker == UNINHABITED_TYPE: + if tag == UNINHABITED_TYPE: return UninhabitedType.read(data) - if marker == UNBOUND_TYPE: + if tag == UNBOUND_TYPE: return UnboundType.read(data) - if marker == DELETED_TYPE: + if tag == DELETED_TYPE: return DeletedType.read(data) - assert False, f"Unknown type marker {marker}" + assert False, f"Unknown type tag {tag}" def read_function_like(data: Buffer) -> FunctionLike: - marker = read_int(data) - if marker == CALLABLE_TYPE: + tag = read_tag(data) + if tag == CALLABLE_TYPE: return CallableType.read(data) - if marker == OVERLOADED: + if tag == OVERLOADED: return Overloaded.read(data) - assert False, f"Invalid type marker for FunctionLike {marker}" + assert False, f"Invalid type tag for FunctionLike {tag}" def read_type_var_like(data: Buffer) -> TypeVarLikeType: - marker = read_int(data) - if marker == TYPE_VAR_TYPE: + tag = read_tag(data) + if tag == TYPE_VAR_TYPE: return TypeVarType.read(data) - if marker == PARAM_SPEC_TYPE: + if tag == PARAM_SPEC_TYPE: return ParamSpecType.read(data) - if marker == TYPE_VAR_TUPLE_TYPE: + if tag == TYPE_VAR_TUPLE_TYPE: return TypeVarTupleType.read(data) - assert False, f"Invalid type marker for TypeVarLikeType {marker}" + assert False, f"Invalid type tag for TypeVarLikeType {tag}" def read_type_opt(data: Buffer) -> Type | None: diff --git a/mypy/typeshed/stubs/mypy-native/native_internal.pyi b/mypy/typeshed/stubs/mypy-native/native_internal.pyi index bc1f570a8e9c..3c6a22c938e3 100644 --- a/mypy/typeshed/stubs/mypy-native/native_internal.pyi +++ b/mypy/typeshed/stubs/mypy-native/native_internal.pyi @@ -10,3 +10,5 @@ def write_float(data: Buffer, value: float) -> None: ... def read_float(data: Buffer) -> float: ... def write_int(data: Buffer, value: int) -> None: ... def read_int(data: Buffer) -> int: ... +def write_tag(data: Buffer, value: int) -> None: ... +def read_tag(data: Buffer) -> int: ... diff --git a/mypyc/lib-rt/native_internal.c b/mypyc/lib-rt/native_internal.c index 11a3fafee56f..1c35eab946f8 100644 --- a/mypyc/lib-rt/native_internal.c +++ b/mypyc/lib-rt/native_internal.c @@ -1,10 +1,12 @@ #define PY_SSIZE_T_CLEAN #include +#include #include "CPy.h" #define NATIVE_INTERNAL_MODULE #include "native_internal.h" #define START_SIZE 512 +#define MAX_SHORT_INT_TAGGED (255 << 1) typedef struct { PyObject_HEAD @@ -436,6 +438,71 @@ write_int(PyObject *self, PyObject *args, PyObject *kwds) { return Py_None; } +static CPyTagged +read_tag_internal(PyObject *data) { + if (_check_buffer(data) == 2) + return CPY_INT_TAG; + + if (_check_read((BufferObject *)data, 1) == 2) + return CPY_INT_TAG; + char *buf = ((BufferObject *)data)->buf; + + uint8_t ret = *(uint8_t *)(buf + ((BufferObject *)data)->pos); + ((BufferObject *)data)->pos += 1; + return ((CPyTagged)ret) << 1; +} + +static PyObject* +read_tag(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"data", NULL}; + PyObject *data = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data)) + return NULL; + CPyTagged retval = read_tag_internal(data); + if (retval == CPY_INT_TAG) { + return NULL; + } + return CPyTagged_StealAsObject(retval); +} + +static char +write_tag_internal(PyObject *data, CPyTagged value) { + if (_check_buffer(data) == 2) + return 2; + + if (value > MAX_SHORT_INT_TAGGED) { + PyErr_SetString(PyExc_OverflowError, "value must fit in single byte"); + return 2; + } + + if (_check_size((BufferObject *)data, 1) == 2) + return 2; + uint8_t *buf = (uint8_t *)((BufferObject *)data)->buf; + *(buf + ((BufferObject *)data)->pos) = (uint8_t)(value >> 1); + ((BufferObject *)data)->pos += 1; + ((BufferObject *)data)->end += 1; + return 1; +} + +static PyObject* +write_tag(PyObject *self, PyObject *args, PyObject *kwds) { + static char *kwlist[] = {"data", "value", NULL}; + PyObject *data = NULL; + PyObject *value = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &data, &value)) + return NULL; + if (!PyLong_Check(value)) { + PyErr_SetString(PyExc_TypeError, "value must be an int"); + return NULL; + } + CPyTagged tagged_value = CPyTagged_BorrowFromObject(value); + if (write_tag_internal(data, tagged_value) == 2) { + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + static PyMethodDef native_internal_module_methods[] = { // TODO: switch public wrappers to METH_FASTCALL. {"write_bool", (PyCFunction)write_bool, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write a bool")}, @@ -446,6 +513,8 @@ static PyMethodDef native_internal_module_methods[] = { {"read_float", (PyCFunction)read_float, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read a float")}, {"write_int", (PyCFunction)write_int, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write an int")}, {"read_int", (PyCFunction)read_int, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read an int")}, + {"write_tag", (PyCFunction)write_tag, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("write a short int")}, + {"read_tag", (PyCFunction)read_tag, METH_VARARGS | METH_KEYWORDS, PyDoc_STR("read a short int")}, {NULL, NULL, 0, NULL} }; @@ -465,7 +534,7 @@ native_internal_module_exec(PyObject *m) } // Export mypy internal C API, be careful with the order! - static void *NativeInternal_API[12] = { + static void *NativeInternal_API[14] = { (void *)Buffer_internal, (void *)Buffer_internal_empty, (void *)Buffer_getvalue_internal, @@ -477,6 +546,8 @@ native_internal_module_exec(PyObject *m) (void *)read_float_internal, (void *)write_int_internal, (void *)read_int_internal, + (void *)write_tag_internal, + (void *)read_tag_internal, (void *)NativeInternal_ABI_Version, }; PyObject *c_api_object = PyCapsule_New((void *)NativeInternal_API, "native_internal._C_API", NULL); diff --git a/mypyc/lib-rt/native_internal.h b/mypyc/lib-rt/native_internal.h index 3bd3dd1bbb33..5a8905f0e6f0 100644 --- a/mypyc/lib-rt/native_internal.h +++ b/mypyc/lib-rt/native_internal.h @@ -16,6 +16,8 @@ static char write_float_internal(PyObject *data, double value); static double read_float_internal(PyObject *data); static char write_int_internal(PyObject *data, CPyTagged value); static CPyTagged read_int_internal(PyObject *data); +static char write_tag_internal(PyObject *data, CPyTagged value); +static CPyTagged read_tag_internal(PyObject *data); static int NativeInternal_ABI_Version(void); #else @@ -33,7 +35,9 @@ static void **NativeInternal_API; #define read_float_internal (*(double (*)(PyObject *source)) NativeInternal_API[8]) #define write_int_internal (*(char (*)(PyObject *source, CPyTagged value)) NativeInternal_API[9]) #define read_int_internal (*(CPyTagged (*)(PyObject *source)) NativeInternal_API[10]) -#define NativeInternal_ABI_Version (*(int (*)(void)) NativeInternal_API[11]) +#define write_tag_internal (*(char (*)(PyObject *source, CPyTagged value)) NativeInternal_API[11]) +#define read_tag_internal (*(CPyTagged (*)(PyObject *source)) NativeInternal_API[12]) +#define NativeInternal_ABI_Version (*(int (*)(void)) NativeInternal_API[13]) static int import_native_internal(void) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 8738255081e2..5875d5d65e9b 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -423,3 +423,19 @@ c_function_name="read_int_internal", error_kind=ERR_MAGIC, ) + +function_op( + name="native_internal.write_tag", + arg_types=[object_rprimitive, int_rprimitive], + return_type=none_rprimitive, + c_function_name="write_tag_internal", + error_kind=ERR_MAGIC, +) + +function_op( + name="native_internal.read_tag", + arg_types=[object_rprimitive], + return_type=int_rprimitive, + c_function_name="read_tag_internal", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 68bc18c7bdeb..3a9657d49f34 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1411,7 +1411,8 @@ class TestOverload: [case testNativeBufferFastPath] from native_internal import ( - Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, write_int, read_int + Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, + write_int, read_int, write_tag, read_tag ) def foo() -> None: @@ -1420,23 +1421,25 @@ def foo() -> None: write_bool(b, True) write_float(b, 0.1) write_int(b, 1) + write_tag(b, 1) b = Buffer(b.getvalue()) x = read_str(b) y = read_bool(b) z = read_float(b) t = read_int(b) + u = read_tag(b) [out] def foo(): r0, b :: native_internal.Buffer r1 :: str - r2, r3, r4, r5 :: None - r6 :: bytes - r7 :: native_internal.Buffer - r8, x :: str - r9, y :: bool - r10, z :: float - r11, t :: int + r2, r3, r4, r5, r6 :: None + r7 :: bytes + r8 :: native_internal.Buffer + r9, x :: str + r10, y :: bool + r11, z :: float + r12, t, r13, u :: int L0: r0 = Buffer_internal_empty() b = r0 @@ -1445,17 +1448,20 @@ L0: r3 = write_bool_internal(b, 1) r4 = write_float_internal(b, 0.1) r5 = write_int_internal(b, 2) - r6 = Buffer_getvalue_internal(b) - r7 = Buffer_internal(r6) - b = r7 - r8 = read_str_internal(b) - x = r8 - r9 = read_bool_internal(b) - y = r9 - r10 = read_float_internal(b) - z = r10 - r11 = read_int_internal(b) - t = r11 + r6 = write_tag_internal(b, 2) + r7 = Buffer_getvalue_internal(b) + r8 = Buffer_internal(r7) + b = r8 + r9 = read_str_internal(b) + x = r9 + r10 = read_bool_internal(b) + y = r10 + r11 = read_float_internal(b) + z = r11 + r12 = read_int_internal(b) + t = r12 + r13 = read_tag_internal(b) + u = r13 return 1 [case testEnumFastPath] diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 6f1217bd36e6..dc64680f67c1 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2712,7 +2712,8 @@ Player.MIN = [case testBufferRoundTrip_native_libs] from native_internal import ( - Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, write_int, read_int + Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float, + write_int, read_int, write_tag, read_tag ) def test_buffer_basic() -> None: @@ -2728,8 +2729,11 @@ def test_buffer_roundtrip() -> None: write_float(b, 0.1) write_int(b, 0) write_int(b, 1) + write_tag(b, 33) + write_tag(b, 255) write_int(b, 2) write_int(b, 2 ** 85) + write_int(b, -1) b = Buffer(b.getvalue()) assert read_str(b) == "foo" @@ -2739,8 +2743,11 @@ def test_buffer_roundtrip() -> None: assert read_float(b) == 0.1 assert read_int(b) == 0 assert read_int(b) == 1 + assert read_tag(b) == 33 + assert read_tag(b) == 255 assert read_int(b) == 2 assert read_int(b) == 2 ** 85 + assert read_int(b) == -1 [file driver.py] from native import * @@ -2761,8 +2768,11 @@ def test_buffer_roundtrip_interpreted() -> None: write_float(b, 0.1) write_int(b, 0) write_int(b, 1) + write_tag(b, 33) + write_tag(b, 255) write_int(b, 2) write_int(b, 2 ** 85) + write_int(b, -1) b = Buffer(b.getvalue()) assert read_str(b) == "foo" @@ -2772,8 +2782,11 @@ def test_buffer_roundtrip_interpreted() -> None: assert read_float(b) == 0.1 assert read_int(b) == 0 assert read_int(b) == 1 + assert read_tag(b) == 33 + assert read_tag(b) == 255 assert read_int(b) == 2 assert read_int(b) == 2 ** 85 + assert read_int(b) == -1 test_buffer_basic_interpreted() test_buffer_roundtrip_interpreted() diff --git a/test-data/unit/lib-stub/native_internal.pyi b/test-data/unit/lib-stub/native_internal.pyi index bc1f570a8e9c..3c6a22c938e3 100644 --- a/test-data/unit/lib-stub/native_internal.pyi +++ b/test-data/unit/lib-stub/native_internal.pyi @@ -10,3 +10,5 @@ def write_float(data: Buffer, value: float) -> None: ... def read_float(data: Buffer) -> float: ... def write_int(data: Buffer, value: int) -> None: ... def read_int(data: Buffer) -> int: ... +def write_tag(data: Buffer, value: int) -> None: ... +def read_tag(data: Buffer) -> int: ...