Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ def __eq__(self, other: object) -> bool:
def serialize(self) -> JsonDict | str:
assert self.type is not None
type_ref = self.type.fullname
if not self.args and not self.last_known_value:
if not self.args and not self.last_known_value and not self.extra_attrs:
return type_ref
data: JsonDict = {
".class": "Instance",
Expand Down Expand Up @@ -1745,7 +1745,6 @@ def copy_modified(
),
extra_attrs=self.extra_attrs,
)
# We intentionally don't copy the extra_attrs here, so they will be erased.
new.can_be_true = self.can_be_true
new.can_be_false = self.can_be_false
return new
Expand Down
133 changes: 108 additions & 25 deletions mypyc/lib-rt/native_internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
#define START_SIZE 512
#define MAX_SHORT_INT_TAGGED (255 << 1)

#define MAX_SHORT_LEN 127
#define LONG_STR_TAG 1

#define MIN_SHORT_INT -10
#define MAX_SHORT_INT 117
#define MEDIUM_INT_TAG 1
#define LONG_INT_TAG 3

typedef struct {
PyObject_HEAD
Py_ssize_t pos;
Expand Down Expand Up @@ -166,6 +174,12 @@ _check_read(BufferObject *data, Py_ssize_t need) {
return 1;
}

/*
bool format: single byte
\x00 - False
\x01 - True
*/

static char
read_bool_internal(PyObject *data) {
if (_check_buffer(data) == 2)
Expand Down Expand Up @@ -225,20 +239,34 @@ write_bool(PyObject *self, PyObject *args, PyObject *kwds) {
return Py_None;
}

/*
str format: size followed by UTF-8 bytes
short strings (len <= 127): single byte for size as `(uint8_t)size << 1`
long strings: \x01 followed by size as Py_ssize_t
*/

static PyObject*
read_str_internal(PyObject *data) {
if (_check_buffer(data) == 2)
return NULL;

if (_check_read((BufferObject *)data, sizeof(Py_ssize_t)) == 2)
return NULL;
Py_ssize_t size;
char *buf = ((BufferObject *)data)->buf;
// Read string length.
Py_ssize_t size = *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += sizeof(Py_ssize_t);
if (_check_read((BufferObject *)data, size) == 2)
if (_check_read((BufferObject *)data, 1) == 2)
return NULL;
uint8_t first = *(uint8_t *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += 1;
if (first != LONG_STR_TAG) {
// Common case: short string (len <= 127).
size = (Py_ssize_t)(first >> 1);
} else {
size = *(Py_ssize_t *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += sizeof(Py_ssize_t);
}
// Read string content.
if (_check_read((BufferObject *)data, size) == 2)
return NULL;
PyObject *res = PyUnicode_FromStringAndSize(
buf + ((BufferObject *)data)->pos, (Py_ssize_t)size
);
Expand Down Expand Up @@ -266,14 +294,28 @@ write_str_internal(PyObject *data, PyObject *value) {
const char *chunk = PyUnicode_AsUTF8AndSize(value, &size);
if (!chunk)
return 2;
Py_ssize_t need = size + sizeof(Py_ssize_t);
if (_check_size((BufferObject *)data, need) == 2)
return 2;

char *buf = ((BufferObject *)data)->buf;
Py_ssize_t need;
char *buf;
// Write string length.
*(Py_ssize_t *)(buf + ((BufferObject *)data)->pos) = size;
((BufferObject *)data)->pos += sizeof(Py_ssize_t);
if (size <= MAX_SHORT_LEN) {
// Common case: short string (len <= 127) store as single byte.
need = size + 1;
if (_check_size((BufferObject *)data, need) == 2)
return 2;
buf = ((BufferObject *)data)->buf;
*(uint8_t *)(buf + ((BufferObject *)data)->pos) = (uint8_t)size << 1;
((BufferObject *)data)->pos += 1;
} else {
need = size + sizeof(Py_ssize_t) + 1;
if (_check_size((BufferObject *)data, need) == 2)
return 2;
buf = ((BufferObject *)data)->buf;
*(uint8_t *)(buf + ((BufferObject *)data)->pos) = LONG_STR_TAG;
((BufferObject *)data)->pos += 1;
*(Py_ssize_t *)(buf + ((BufferObject *)data)->pos) = size;
((BufferObject *)data)->pos += sizeof(Py_ssize_t);
}
// Write string content.
memcpy(buf + ((BufferObject *)data)->pos, chunk, size);
((BufferObject *)data)->pos += size;
Expand All @@ -299,6 +341,11 @@ write_str(PyObject *self, PyObject *args, PyObject *kwds) {
return Py_None;
}

/*
float format:
stored as a C double
*/

static double
read_float_internal(PyObject *data) {
if (_check_buffer(data) == 2)
Expand Down Expand Up @@ -357,19 +404,33 @@ write_float(PyObject *self, PyObject *args, PyObject *kwds) {
return Py_None;
}

/*
int format:
most common values (-10 <= value <= 117): single byte as `(uint8_t)(value + 10) << 1`
medium values (fit in CPyTagged): \x01 followed by CPyTagged value
long values (very rare): \x03 followed by decimal string (see str format)
*/

static CPyTagged
read_int_internal(PyObject *data) {
if (_check_buffer(data) == 2)
return CPY_INT_TAG;

if (_check_read((BufferObject *)data, sizeof(CPyTagged)) == 2)
return CPY_INT_TAG;
char *buf = ((BufferObject *)data)->buf;
if (_check_read((BufferObject *)data, 1) == 2)
return CPY_INT_TAG;

CPyTagged ret = *(CPyTagged *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += sizeof(CPyTagged);
if ((ret & CPY_INT_TAG) == 0)
uint8_t first = *(uint8_t *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += 1;
if ((first & MEDIUM_INT_TAG) == 0) {
// Most common case: int that is small in absolute value.
return ((Py_ssize_t)(first >> 1) + MIN_SHORT_INT) << 1;
}
if (first == MEDIUM_INT_TAG) {
CPyTagged ret = *(CPyTagged *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += sizeof(CPyTagged);
return ret;
}
// People who have literal ints not fitting in size_t should be punished :-)
PyObject *str_ret = read_str_internal(data);
if (str_ret == NULL)
Expand Down Expand Up @@ -397,17 +458,34 @@ write_int_internal(PyObject *data, CPyTagged value) {
if (_check_buffer(data) == 2)
return 2;

if (_check_size((BufferObject *)data, sizeof(CPyTagged)) == 2)
return 2;
char *buf = ((BufferObject *)data)->buf;
char *buf;
if ((value & CPY_INT_TAG) == 0) {
*(CPyTagged *)(buf + ((BufferObject *)data)->pos) = value;
Py_ssize_t real_value = CPyTagged_ShortAsSsize_t(value);
if (real_value >= MIN_SHORT_INT && real_value <= MAX_SHORT_INT) {
// Most common case: int that is small in absolute value.
if (_check_size((BufferObject *)data, 1) == 2)
return 2;
buf = ((BufferObject *)data)->buf;
*(uint8_t *)(buf + ((BufferObject *)data)->pos) = (uint8_t)(real_value - MIN_SHORT_INT) << 1;
((BufferObject *)data)->pos += 1;
((BufferObject *)data)->end += 1;
} else {
if (_check_size((BufferObject *)data, sizeof(CPyTagged) + 1) == 2)
return 2;
buf = ((BufferObject *)data)->buf;
*(uint8_t *)(buf + ((BufferObject *)data)->pos) = MEDIUM_INT_TAG;
((BufferObject *)data)->pos += 1;
*(CPyTagged *)(buf + ((BufferObject *)data)->pos) = value;
((BufferObject *)data)->pos += sizeof(CPyTagged);
((BufferObject *)data)->end += sizeof(CPyTagged) + 1;
}
} else {
*(CPyTagged *)(buf + ((BufferObject *)data)->pos) = CPY_INT_TAG;
}
((BufferObject *)data)->pos += sizeof(CPyTagged);
((BufferObject *)data)->end += sizeof(CPyTagged);
if ((value & CPY_INT_TAG) != 0) {
if (_check_size((BufferObject *)data, 1) == 2)
return 2;
buf = ((BufferObject *)data)->buf;
*(uint8_t *)(buf + ((BufferObject *)data)->pos) = LONG_INT_TAG;
((BufferObject *)data)->pos += 1;
((BufferObject *)data)->end += 1;
PyObject *str_value = PyObject_Str(CPyTagged_LongAsObject(value));
if (str_value == NULL)
return 2;
Expand Down Expand Up @@ -438,6 +516,11 @@ write_int(PyObject *self, PyObject *args, PyObject *kwds) {
return Py_None;
}

/*
integer tag format (0 <= t <= 255):
stored as a uint8_t
*/

static uint8_t
read_tag_internal(PyObject *data) {
if (_check_buffer(data) == 2)
Expand Down
2 changes: 1 addition & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@
arg_types=[object_rprimitive, uint8_rprimitive],
return_type=none_rprimitive,
c_function_name="write_tag_internal",
error_kind=ERR_MAGIC_OVERLAPPING,
error_kind=ERR_MAGIC,
)

function_op(
Expand Down
64 changes: 64 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2741,7 +2741,11 @@ def test_buffer_roundtrip() -> None:
write_tag(b, TAG_B)
write_int(b, 2)
write_int(b, 2 ** 85)
write_int(b, 255)
write_int(b, -1)
write_int(b, -255)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test also the edge cases (-11, -10, -9, 116, 117, 118). Test a few more different lengths of integers (e.g. 15 bits, 23 bits, 30 bits) with arbitrary lower bits.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I think it would also make sense to test something like len(data.getvalue()) == 1 etc.

write_int(b, 1234512344)
write_int(b, 1234512345)

b = Buffer(b.getvalue())
assert read_str(b) == "foo"
Expand All @@ -2756,13 +2760,41 @@ def test_buffer_roundtrip() -> None:
assert read_tag(b) == TAG_B
assert read_int(b) == 2
assert read_int(b) == 2 ** 85
assert read_int(b) == 255
assert read_int(b) == -1
assert read_int(b) == -255
assert read_int(b) == 1234512344
assert read_int(b) == 1234512345

def test_buffer_int_size() -> None:
for i in (-10, -9, 0, 116, 117):
b = Buffer()
write_int(b, i)
assert len(b.getvalue()) == 1
b = Buffer(b.getvalue())
assert read_int(b) == i
for i in (-12345, -12344, -11, 118, 12344, 12345):
b = Buffer()
write_int(b, i)
assert len(b.getvalue()) <= 9 # sizeof(size_t) + 1
b = Buffer(b.getvalue())
assert read_int(b) == i

def test_buffer_str_size() -> None:
for s in ("", "a", "a" * 127):
b = Buffer()
write_str(b, s)
assert len(b.getvalue()) == len(s) + 1
b = Buffer(b.getvalue())
assert read_str(b) == s

[file driver.py]
from native import *

test_buffer_basic()
test_buffer_roundtrip()
test_buffer_int_size()
test_buffer_str_size()

def test_buffer_basic_interpreted() -> None:
b = Buffer(b"foo")
Expand All @@ -2782,7 +2814,11 @@ def test_buffer_roundtrip_interpreted() -> None:
write_tag(b, 255)
write_int(b, 2)
write_int(b, 2 ** 85)
write_int(b, 255)
write_int(b, -1)
write_int(b, -255)
write_int(b, 1234512344)
write_int(b, 1234512345)

b = Buffer(b.getvalue())
assert read_str(b) == "foo"
Expand All @@ -2797,10 +2833,38 @@ def test_buffer_roundtrip_interpreted() -> None:
assert read_tag(b) == 255
assert read_int(b) == 2
assert read_int(b) == 2 ** 85
assert read_int(b) == 255
assert read_int(b) == -1
assert read_int(b) == -255
assert read_int(b) == 1234512344
assert read_int(b) == 1234512345

def test_buffer_int_size_interpreted() -> None:
for i in (-10, -9, 0, 116, 117):
b = Buffer()
write_int(b, i)
assert len(b.getvalue()) == 1
b = Buffer(b.getvalue())
assert read_int(b) == i
for i in (-12345, -12344, -11, 118, 12344, 12345):
b = Buffer()
write_int(b, i)
assert len(b.getvalue()) <= 9 # sizeof(size_t) + 1
b = Buffer(b.getvalue())
assert read_int(b) == i

def test_buffer_str_size_interpreted() -> None:
for s in ("", "a", "a" * 127):
b = Buffer()
write_str(b, s)
assert len(b.getvalue()) == len(s) + 1
b = Buffer(b.getvalue())
assert read_str(b) == s

test_buffer_basic_interpreted()
test_buffer_roundtrip_interpreted()
test_buffer_int_size_interpreted()
test_buffer_str_size_interpreted()

[case testEnumMethodCalls]
from enum import Enum
Expand Down
Loading