diff --git a/mypy/types.py b/mypy/types.py index 8d5648ae0bda..e0265e601e0c 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1674,7 +1674,7 @@ def __eq__(self, other: object) -> bool: def serialize(self) -> JsonDict | str: assert self.type is not None type_ref = self.type.fullname - if not self.args and not self.last_known_value: + if not self.args and not self.last_known_value and not self.extra_attrs: return type_ref data: JsonDict = { ".class": "Instance", @@ -1745,7 +1745,6 @@ def copy_modified( ), extra_attrs=self.extra_attrs, ) - # We intentionally don't copy the extra_attrs here, so they will be erased. new.can_be_true = self.can_be_true new.can_be_false = self.can_be_false return new diff --git a/mypyc/lib-rt/native_internal.c b/mypyc/lib-rt/native_internal.c index 3228f0330793..1c211464b19c 100644 --- a/mypyc/lib-rt/native_internal.c +++ b/mypyc/lib-rt/native_internal.c @@ -8,6 +8,14 @@ #define START_SIZE 512 #define MAX_SHORT_INT_TAGGED (255 << 1) +#define MAX_SHORT_LEN 127 +#define LONG_STR_TAG 1 + +#define MIN_SHORT_INT -10 +#define MAX_SHORT_INT 117 +#define MEDIUM_INT_TAG 1 +#define LONG_INT_TAG 3 + typedef struct { PyObject_HEAD Py_ssize_t pos; @@ -166,6 +174,12 @@ _check_read(BufferObject *data, Py_ssize_t need) { return 1; } +/* +bool format: single byte + \x00 - False + \x01 - True +*/ + static char read_bool_internal(PyObject *data) { if (_check_buffer(data) == 2) @@ -225,20 +239,34 @@ write_bool(PyObject *self, PyObject *args, PyObject *kwds) { return Py_None; } +/* +str format: size followed by UTF-8 bytes + short strings (len <= 127): single byte for size as `(uint8_t)size << 1` + long strings: \x01 followed by size as Py_ssize_t +*/ + static PyObject* read_str_internal(PyObject *data) { if (_check_buffer(data) == 2) return NULL; - if (_check_read((BufferObject *)data, sizeof(Py_ssize_t)) == 2) - return NULL; + Py_ssize_t size; char *buf = ((BufferObject *)data)->buf; // Read string length. - Py_ssize_t size = *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos); - ((BufferObject *)data)->pos += sizeof(Py_ssize_t); - if (_check_read((BufferObject *)data, size) == 2) + if (_check_read((BufferObject *)data, 1) == 2) return NULL; + uint8_t first = *(uint8_t *)(buf + ((BufferObject *)data)->pos); + ((BufferObject *)data)->pos += 1; + if (first != LONG_STR_TAG) { + // Common case: short string (len <= 127). + size = (Py_ssize_t)(first >> 1); + } else { + size = *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos); + ((BufferObject *)data)->pos += sizeof(Py_ssize_t); + } // Read string content. + if (_check_read((BufferObject *)data, size) == 2) + return NULL; PyObject *res = PyUnicode_FromStringAndSize( buf + ((BufferObject *)data)->pos, (Py_ssize_t)size ); @@ -266,14 +294,28 @@ write_str_internal(PyObject *data, PyObject *value) { const char *chunk = PyUnicode_AsUTF8AndSize(value, &size); if (!chunk) return 2; - Py_ssize_t need = size + sizeof(Py_ssize_t); - if (_check_size((BufferObject *)data, need) == 2) - return 2; - char *buf = ((BufferObject *)data)->buf; + Py_ssize_t need; + char *buf; // Write string length. - *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos) = size; - ((BufferObject *)data)->pos += sizeof(Py_ssize_t); + if (size <= MAX_SHORT_LEN) { + // Common case: short string (len <= 127) store as single byte. + need = size + 1; + if (_check_size((BufferObject *)data, need) == 2) + return 2; + buf = ((BufferObject *)data)->buf; + *(uint8_t *)(buf + ((BufferObject *)data)->pos) = (uint8_t)size << 1; + ((BufferObject *)data)->pos += 1; + } else { + need = size + sizeof(Py_ssize_t) + 1; + if (_check_size((BufferObject *)data, need) == 2) + return 2; + buf = ((BufferObject *)data)->buf; + *(uint8_t *)(buf + ((BufferObject *)data)->pos) = LONG_STR_TAG; + ((BufferObject *)data)->pos += 1; + *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos) = size; + ((BufferObject *)data)->pos += sizeof(Py_ssize_t); + } // Write string content. memcpy(buf + ((BufferObject *)data)->pos, chunk, size); ((BufferObject *)data)->pos += size; @@ -299,6 +341,11 @@ write_str(PyObject *self, PyObject *args, PyObject *kwds) { return Py_None; } +/* +float format: + stored as a C double +*/ + static double read_float_internal(PyObject *data) { if (_check_buffer(data) == 2) @@ -357,19 +404,33 @@ write_float(PyObject *self, PyObject *args, PyObject *kwds) { return Py_None; } +/* +int format: + most common values (-10 <= value <= 117): single byte as `(uint8_t)(value + 10) << 1` + medium values (fit in CPyTagged): \x01 followed by CPyTagged value + long values (very rare): \x03 followed by decimal string (see str format) +*/ + static CPyTagged read_int_internal(PyObject *data) { if (_check_buffer(data) == 2) return CPY_INT_TAG; - if (_check_read((BufferObject *)data, sizeof(CPyTagged)) == 2) - return CPY_INT_TAG; char *buf = ((BufferObject *)data)->buf; + if (_check_read((BufferObject *)data, 1) == 2) + return CPY_INT_TAG; - CPyTagged ret = *(CPyTagged *)(buf + ((BufferObject *)data)->pos); - ((BufferObject *)data)->pos += sizeof(CPyTagged); - if ((ret & CPY_INT_TAG) == 0) + uint8_t first = *(uint8_t *)(buf + ((BufferObject *)data)->pos); + ((BufferObject *)data)->pos += 1; + if ((first & MEDIUM_INT_TAG) == 0) { + // Most common case: int that is small in absolute value. + return ((Py_ssize_t)(first >> 1) + MIN_SHORT_INT) << 1; + } + if (first == MEDIUM_INT_TAG) { + CPyTagged ret = *(CPyTagged *)(buf + ((BufferObject *)data)->pos); + ((BufferObject *)data)->pos += sizeof(CPyTagged); return ret; + } // People who have literal ints not fitting in size_t should be punished :-) PyObject *str_ret = read_str_internal(data); if (str_ret == NULL) @@ -397,17 +458,34 @@ write_int_internal(PyObject *data, CPyTagged value) { if (_check_buffer(data) == 2) return 2; - if (_check_size((BufferObject *)data, sizeof(CPyTagged)) == 2) - return 2; - char *buf = ((BufferObject *)data)->buf; + char *buf; if ((value & CPY_INT_TAG) == 0) { - *(CPyTagged *)(buf + ((BufferObject *)data)->pos) = value; + Py_ssize_t real_value = CPyTagged_ShortAsSsize_t(value); + if (real_value >= MIN_SHORT_INT && real_value <= MAX_SHORT_INT) { + // Most common case: int that is small in absolute value. + if (_check_size((BufferObject *)data, 1) == 2) + return 2; + buf = ((BufferObject *)data)->buf; + *(uint8_t *)(buf + ((BufferObject *)data)->pos) = (uint8_t)(real_value - MIN_SHORT_INT) << 1; + ((BufferObject *)data)->pos += 1; + ((BufferObject *)data)->end += 1; + } else { + if (_check_size((BufferObject *)data, sizeof(CPyTagged) + 1) == 2) + return 2; + buf = ((BufferObject *)data)->buf; + *(uint8_t *)(buf + ((BufferObject *)data)->pos) = MEDIUM_INT_TAG; + ((BufferObject *)data)->pos += 1; + *(CPyTagged *)(buf + ((BufferObject *)data)->pos) = value; + ((BufferObject *)data)->pos += sizeof(CPyTagged); + ((BufferObject *)data)->end += sizeof(CPyTagged) + 1; + } } else { - *(CPyTagged *)(buf + ((BufferObject *)data)->pos) = CPY_INT_TAG; - } - ((BufferObject *)data)->pos += sizeof(CPyTagged); - ((BufferObject *)data)->end += sizeof(CPyTagged); - if ((value & CPY_INT_TAG) != 0) { + if (_check_size((BufferObject *)data, 1) == 2) + return 2; + buf = ((BufferObject *)data)->buf; + *(uint8_t *)(buf + ((BufferObject *)data)->pos) = LONG_INT_TAG; + ((BufferObject *)data)->pos += 1; + ((BufferObject *)data)->end += 1; PyObject *str_value = PyObject_Str(CPyTagged_LongAsObject(value)); if (str_value == NULL) return 2; @@ -438,6 +516,11 @@ write_int(PyObject *self, PyObject *args, PyObject *kwds) { return Py_None; } +/* +integer tag format (0 <= t <= 255): + stored as a uint8_t +*/ + static uint8_t read_tag_internal(PyObject *data) { if (_check_buffer(data) == 2) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 943f6fc04b72..8e6e450c64dc 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -430,7 +430,7 @@ arg_types=[object_rprimitive, uint8_rprimitive], return_type=none_rprimitive, c_function_name="write_tag_internal", - error_kind=ERR_MAGIC_OVERLAPPING, + error_kind=ERR_MAGIC, ) function_op( diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index edc989ea641c..2e55ee70687e 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2741,7 +2741,11 @@ def test_buffer_roundtrip() -> None: write_tag(b, TAG_B) write_int(b, 2) write_int(b, 2 ** 85) + write_int(b, 255) write_int(b, -1) + write_int(b, -255) + write_int(b, 1234512344) + write_int(b, 1234512345) b = Buffer(b.getvalue()) assert read_str(b) == "foo" @@ -2756,13 +2760,41 @@ def test_buffer_roundtrip() -> None: assert read_tag(b) == TAG_B assert read_int(b) == 2 assert read_int(b) == 2 ** 85 + assert read_int(b) == 255 assert read_int(b) == -1 + assert read_int(b) == -255 + assert read_int(b) == 1234512344 + assert read_int(b) == 1234512345 + +def test_buffer_int_size() -> None: + for i in (-10, -9, 0, 116, 117): + b = Buffer() + write_int(b, i) + assert len(b.getvalue()) == 1 + b = Buffer(b.getvalue()) + assert read_int(b) == i + for i in (-12345, -12344, -11, 118, 12344, 12345): + b = Buffer() + write_int(b, i) + assert len(b.getvalue()) <= 9 # sizeof(size_t) + 1 + b = Buffer(b.getvalue()) + assert read_int(b) == i + +def test_buffer_str_size() -> None: + for s in ("", "a", "a" * 127): + b = Buffer() + write_str(b, s) + assert len(b.getvalue()) == len(s) + 1 + b = Buffer(b.getvalue()) + assert read_str(b) == s [file driver.py] from native import * test_buffer_basic() test_buffer_roundtrip() +test_buffer_int_size() +test_buffer_str_size() def test_buffer_basic_interpreted() -> None: b = Buffer(b"foo") @@ -2782,7 +2814,11 @@ def test_buffer_roundtrip_interpreted() -> None: write_tag(b, 255) write_int(b, 2) write_int(b, 2 ** 85) + write_int(b, 255) write_int(b, -1) + write_int(b, -255) + write_int(b, 1234512344) + write_int(b, 1234512345) b = Buffer(b.getvalue()) assert read_str(b) == "foo" @@ -2797,10 +2833,38 @@ def test_buffer_roundtrip_interpreted() -> None: assert read_tag(b) == 255 assert read_int(b) == 2 assert read_int(b) == 2 ** 85 + assert read_int(b) == 255 assert read_int(b) == -1 + assert read_int(b) == -255 + assert read_int(b) == 1234512344 + assert read_int(b) == 1234512345 + +def test_buffer_int_size_interpreted() -> None: + for i in (-10, -9, 0, 116, 117): + b = Buffer() + write_int(b, i) + assert len(b.getvalue()) == 1 + b = Buffer(b.getvalue()) + assert read_int(b) == i + for i in (-12345, -12344, -11, 118, 12344, 12345): + b = Buffer() + write_int(b, i) + assert len(b.getvalue()) <= 9 # sizeof(size_t) + 1 + b = Buffer(b.getvalue()) + assert read_int(b) == i + +def test_buffer_str_size_interpreted() -> None: + for s in ("", "a", "a" * 127): + b = Buffer() + write_str(b, s) + assert len(b.getvalue()) == len(s) + 1 + b = Buffer(b.getvalue()) + assert read_str(b) == s test_buffer_basic_interpreted() test_buffer_roundtrip_interpreted() +test_buffer_int_size_interpreted() +test_buffer_str_size_interpreted() [case testEnumMethodCalls] from enum import Enum