Add write_bytes/read_bytes to librt.internal.

Covers the typeshed stub, the C implementation (exported through slots 14
and 15 of the _C_API capsule), the mypyc primitive ops, and IR-build/run
tests. Bytes are stored as a length prefix followed by the raw content; the
long-form length is written, bounds-checked and read as a Py_ssize_t.

diff --git a/mypy/typeshed/stubs/librt/librt/internal.pyi b/mypy/typeshed/stubs/librt/librt/internal.pyi
index a47a4849fe20..8a5fc262931e 100644
--- a/mypy/typeshed/stubs/librt/librt/internal.pyi
+++ b/mypy/typeshed/stubs/librt/librt/internal.pyi
@@ -8,6 +8,8 @@ def write_bool(data: Buffer, value: bool) -> None: ...
 def read_bool(data: Buffer) -> bool: ...
 def write_str(data: Buffer, value: str) -> None: ...
 def read_str(data: Buffer) -> str: ...
+def write_bytes(data: Buffer, value: bytes) -> None: ...
+def read_bytes(data: Buffer) -> bytes: ...
 def write_float(data: Buffer, value: float) -> None: ...
 def read_float(data: Buffer) -> float: ...
 def write_int(data: Buffer, value: int) -> None: ...
diff --git a/mypyc/lib-rt/librt_internal.c b/mypyc/lib-rt/librt_internal.c
index b97d6665b515..6f6a110446ad 100644
--- a/mypyc/lib-rt/librt_internal.c
+++ b/mypyc/lib-rt/librt_internal.c
@@ -346,6 +346,100 @@ write_str(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames
     return Py_None;
 }
 
+/*
+bytes format: size followed by bytes
+    short bytes (len <= 127): single byte for size as `(uint8_t)size << 1`
+    long bytes: \x01 followed by size as Py_ssize_t
+*/
+
+static PyObject*
+read_bytes_internal(PyObject *data) {
+    _CHECK_BUFFER(data, NULL)
+
+    // Read length.
+    Py_ssize_t size;
+    _CHECK_READ(data, 1, NULL)
+    uint8_t first = _READ(data, uint8_t)
+    if (likely(first != LONG_STR_TAG)) {
+        // Common case: short bytes (len <= 127).
+        size = (Py_ssize_t)(first >> 1);
+    } else {
+        _CHECK_READ(data, sizeof(Py_ssize_t), NULL)
+        size = _READ(data, Py_ssize_t)
+    }
+    // Read bytes content.
+    char *buf = ((BufferObject *)data)->buf;
+    _CHECK_READ(data, size, NULL)
+    PyObject *res = PyBytes_FromStringAndSize(
+        buf + ((BufferObject *)data)->pos, (Py_ssize_t)size
+    );
+    if (unlikely(res == NULL))
+        return NULL;
+    ((BufferObject *)data)->pos += size;
+    return res;
+}
+
+static PyObject*
+read_bytes(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) {
+    static const char * const kwlist[] = {"data", 0};
+    static CPyArg_Parser parser = {"O:read_bytes", kwlist, 0};
+    PyObject *data;
+    if (unlikely(!CPyArg_ParseStackAndKeywordsOneArg(args, nargs, kwnames, &parser, &data))) {
+        return NULL;
+    }
+    return read_bytes_internal(data);
+}
+
+static char
+write_bytes_internal(PyObject *data, PyObject *value) {
+    _CHECK_BUFFER(data, CPY_NONE_ERROR)
+
+    const char *chunk = PyBytes_AsString(value);
+    if (unlikely(chunk == NULL))
+        return CPY_NONE_ERROR;
+    Py_ssize_t size = PyBytes_GET_SIZE(value);
+
+    Py_ssize_t need;
+    // Write length.
+    if (likely(size <= MAX_SHORT_LEN)) {
+        // Common case: short bytes (len <= 127) store as single byte.
+        need = size + 1;
+        _CHECK_SIZE(data, need)
+        _WRITE(data, uint8_t, (uint8_t)size << 1)
+    } else {
+        need = size + sizeof(Py_ssize_t) + 1;
+        _CHECK_SIZE(data, need)
+        _WRITE(data, uint8_t, LONG_STR_TAG)
+        _WRITE(data, Py_ssize_t, size)
+    }
+    // Write bytes content.
+    char *buf = ((BufferObject *)data)->buf;
+    memcpy(buf + ((BufferObject *)data)->pos, chunk, size);
+    ((BufferObject *)data)->pos += size;
+    ((BufferObject *)data)->end += need;
+    return CPY_NONE;
+}
+
+static PyObject*
+write_bytes(PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames) {
+    static const char * const kwlist[] = {"data", "value", 0};
+    static CPyArg_Parser parser = {"OO:write_bytes", kwlist, 0};
+    PyObject *data;
+    PyObject *value;
+    if (unlikely(!CPyArg_ParseStackAndKeywordsSimple(args, nargs, kwnames, &parser, &data, &value))) {
+        return NULL;
+    }
+    if (unlikely(!PyBytes_Check(value))) {
+        PyErr_SetString(PyExc_TypeError, "value must be a bytes object");
+        return NULL;
+    }
+    if (unlikely(write_bytes_internal(data, value) == CPY_NONE_ERROR)) {
+        return NULL;
+    }
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
 /*
 float format:
     stored as a C double
@@ -565,6 +659,8 @@ static PyMethodDef librt_internal_module_methods[] = {
     {"read_bool", (PyCFunction)read_bool, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("read a bool")},
     {"write_str", (PyCFunction)write_str, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("write a string")},
     {"read_str", (PyCFunction)read_str, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("read a string")},
+    {"write_bytes", (PyCFunction)write_bytes, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("write bytes")},
+    {"read_bytes", (PyCFunction)read_bytes, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("read bytes")},
     {"write_float", (PyCFunction)write_float, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("write a float")},
     {"read_float", (PyCFunction)read_float, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("read a float")},
     {"write_int", (PyCFunction)write_int, METH_FASTCALL | METH_KEYWORDS, PyDoc_STR("write an int")},
@@ -590,7 +686,7 @@ librt_internal_module_exec(PyObject *m)
     }
 
     // Export mypy internal C API, be careful with the order!
-    static void *NativeInternal_API[14] = {
+    static void *NativeInternal_API[16] = {
         (void *)Buffer_internal,
         (void *)Buffer_internal_empty,
         (void *)Buffer_getvalue_internal,
@@ -605,6 +701,8 @@ librt_internal_module_exec(PyObject *m)
         (void *)write_tag_internal,
         (void *)read_tag_internal,
         (void *)NativeInternal_ABI_Version,
+        (void *)write_bytes_internal,
+        (void *)read_bytes_internal,
     };
     PyObject *c_api_object = PyCapsule_New((void *)NativeInternal_API, "librt.internal._C_API", NULL);
     if (PyModule_Add(m, "_C_API", c_api_object) < 0) {
diff --git a/mypyc/lib-rt/librt_internal.h b/mypyc/lib-rt/librt_internal.h
index fd8ec2422cc5..d996b8fd95c1 100644
--- a/mypyc/lib-rt/librt_internal.h
+++ b/mypyc/lib-rt/librt_internal.h
@@ -19,6 +19,8 @@ static CPyTagged read_int_internal(PyObject *data);
 static char write_tag_internal(PyObject *data, uint8_t value);
 static uint8_t read_tag_internal(PyObject *data);
 static int NativeInternal_ABI_Version(void);
+static char write_bytes_internal(PyObject *data, PyObject *value);
+static PyObject *read_bytes_internal(PyObject *data);
 
 #else
 
@@ -38,6 +40,8 @@ static void **NativeInternal_API;
 #define write_tag_internal (*(char (*)(PyObject *source, uint8_t value)) NativeInternal_API[11])
 #define read_tag_internal (*(uint8_t (*)(PyObject *source)) NativeInternal_API[12])
 #define NativeInternal_ABI_Version (*(int (*)(void)) NativeInternal_API[13])
+#define write_bytes_internal (*(char (*)(PyObject *source, PyObject *value)) NativeInternal_API[14])
+#define read_bytes_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[15])
 
 static int
 import_librt_internal(void)
diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py
index 18d475fe89d4..c12172875e8b 100644
--- a/mypyc/primitives/misc_ops.py
+++ b/mypyc/primitives/misc_ops.py
@@ -393,6 +393,22 @@
     error_kind=ERR_MAGIC,
 )
 
+function_op(
+    name="librt.internal.write_bytes",
+    arg_types=[object_rprimitive, bytes_rprimitive],
+    return_type=none_rprimitive,
+    c_function_name="write_bytes_internal",
+    error_kind=ERR_MAGIC,
+)
+
+function_op(
+    name="librt.internal.read_bytes",
+    arg_types=[object_rprimitive],
+    return_type=bytes_rprimitive,
+    c_function_name="read_bytes_internal",
+    error_kind=ERR_MAGIC,
+)
+
 function_op(
     name="librt.internal.write_float",
     arg_types=[object_rprimitive, float_rprimitive],
diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test
index 3280b21cf7e6..27ffba45ba39 100644
--- a/mypyc/test-data/irbuild-classes.test
+++ b/mypyc/test-data/irbuild-classes.test
@@ -1454,7 +1454,7 @@ from typing import Final
 from mypy_extensions import u8
 from librt.internal import (
     Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float,
-    write_int, read_int, write_tag, read_tag
+    write_int, read_int, write_tag, read_tag, write_bytes, read_bytes,
 )
 
 Tag = u8
@@ -1463,6 +1463,7 @@ TAG: Final[Tag] = 1
 def foo() -> None:
     b = Buffer()
     write_str(b, "foo")
+    write_bytes(b, b"bar")
     write_bool(b, True)
     write_float(b, 0.1)
     write_int(b, 1)
@@ -1470,6 +1471,7 @@ def foo() -> None:
 
     b = Buffer(b.getvalue())
     x = read_str(b)
+    xb = read_bytes(b)
     y = read_bool(b)
     z = read_float(b)
     t = read_int(b)
@@ -1478,36 +1480,43 @@ def foo() -> None:
 def foo():
     r0, b :: librt.internal.Buffer
     r1 :: str
-    r2, r3, r4, r5, r6 :: None
-    r7 :: bytes
-    r8 :: librt.internal.Buffer
-    r9, x :: str
-    r10, y :: bool
-    r11, z :: float
-    r12, t :: int
-    r13, u :: u8
+    r2 :: None
+    r3 :: bytes
+    r4, r5, r6, r7, r8 :: None
+    r9 :: bytes
+    r10 :: librt.internal.Buffer
+    r11, x :: str
+    r12, xb :: bytes
+    r13, y :: bool
+    r14, z :: float
+    r15, t :: int
+    r16, u :: u8
 L0:
     r0 = Buffer_internal_empty()
     b = r0
     r1 = 'foo'
     r2 = write_str_internal(b, r1)
-    r3 = write_bool_internal(b, 1)
-    r4 = write_float_internal(b, 0.1)
-    r5 = write_int_internal(b, 2)
-    r6 = write_tag_internal(b, 1)
-    r7 = Buffer_getvalue_internal(b)
-    r8 = Buffer_internal(r7)
-    b = r8
-    r9 = read_str_internal(b)
-    x = r9
-    r10 = read_bool_internal(b)
-    y = r10
-    r11 = read_float_internal(b)
-    z = r11
-    r12 = read_int_internal(b)
-    t = r12
-    r13 = read_tag_internal(b)
-    u = r13
+    r3 = b'bar'
+    r4 = write_bytes_internal(b, r3)
+    r5 = write_bool_internal(b, 1)
+    r6 = write_float_internal(b, 0.1)
+    r7 = write_int_internal(b, 2)
+    r8 = write_tag_internal(b, 1)
+    r9 = Buffer_getvalue_internal(b)
+    r10 = Buffer_internal(r9)
+    b = r10
+    r11 = read_str_internal(b)
+    x = r11
+    r12 = read_bytes_internal(b)
+    xb = r12
+    r13 = read_bool_internal(b)
+    y = r13
+    r14 = read_float_internal(b)
+    z = r14
+    r15 = read_int_internal(b)
+    t = r15
+    r16 = read_tag_internal(b)
+    u = r16
     return 1
 
 [case testEnumFastPath]
diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test
index 84704ce66c81..efa6c225ecab 100644
--- a/mypyc/test-data/run-classes.test
+++ b/mypyc/test-data/run-classes.test
@@ -2715,7 +2715,7 @@ from typing import Final
 from mypy_extensions import u8
 from librt.internal import (
     Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float,
-    write_int, read_int, write_tag, read_tag
+    write_int, read_int, write_tag, read_tag, write_bytes, read_bytes
 )
 
 Tag = u8
@@ -2733,6 +2733,11 @@ def test_buffer_roundtrip() -> None:
     write_bool(b, True)
     write_str(b, "bar" * 1000)
     write_bool(b, False)
+    write_bytes(b, b"bar")
+    write_bytes(b, b"bar" * 100)
+    write_bytes(b, b"")
+    write_bytes(b, b"a" * 127)
+    write_bytes(b, b"a" * 128)
     write_float(b, 0.1)
     write_int(b, 0)
     write_int(b, 1)
@@ -2752,6 +2757,11 @@ def test_buffer_roundtrip() -> None:
     assert read_bool(b) is True
     assert read_str(b) == "bar" * 1000
     assert read_bool(b) is False
+    assert read_bytes(b) == b"bar"
+    assert read_bytes(b) == b"bar" * 100
+    assert read_bytes(b) == b""
+    assert read_bytes(b) == b"a" * 127
+    assert read_bytes(b) == b"a" * 128
     assert read_float(b) == 0.1
     assert read_int(b) == 0
     assert read_int(b) == 1
@@ -2806,6 +2816,11 @@ def test_buffer_roundtrip_interpreted() -> None:
     write_bool(b, True)
     write_str(b, "bar" * 1000)
     write_bool(b, False)
+    write_bytes(b, b"bar")
+    write_bytes(b, b"bar" * 100)
+    write_bytes(b, b"")
+    write_bytes(b, b"a" * 127)
+    write_bytes(b, b"a" * 128)
     write_float(b, 0.1)
     write_int(b, 0)
     write_int(b, 1)
@@ -2825,6 +2840,11 @@ def test_buffer_roundtrip_interpreted() -> None:
     assert read_bool(b) is True
     assert read_str(b) == "bar" * 1000
     assert read_bool(b) is False
+    assert read_bytes(b) == b"bar"
+    assert read_bytes(b) == b"bar" * 100
+    assert read_bytes(b) == b""
+    assert read_bytes(b) == b"a" * 127
+    assert read_bytes(b) == b"a" * 128
     assert read_float(b) == 0.1
     assert read_int(b) == 0
     assert read_int(b) == 1