diff --git a/docs/source/api.rst b/docs/source/api.rst index f7602e03..9b455220 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -62,7 +62,7 @@ JSON :members: encode, encode_lines, encode_into .. autoclass:: Decoder - :members: decode, decode_lines + :members: decode, decode_lines, raw_decode .. autofunction:: encode @@ -80,7 +80,7 @@ MessagePack :members: encode, encode_into .. autoclass:: Decoder - :members: decode + :members: decode, raw_decode .. autoclass:: Ext :members: diff --git a/msgspec/_core.c b/msgspec/_core.c index bfb73680..d19a75b6 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -16128,11 +16128,71 @@ Decoder_decode(Decoder *self, PyObject *const *args, Py_ssize_t nargs) return NULL; } +PyDoc_STRVAR(Decoder_raw_decode__doc__, +"raw_decode(self, buf)\n" +"--\n" +"\n" +"Deserialize an object from MessagePack, allowing trailing data.\n" +"\n" +"Parameters\n" +"----------\n" +"buf : bytes-like\n" +" The message to decode.\n" +"\n" +"Returns\n" +"-------\n" +"obj_and_index : 2-tuple of Any and int\n" +" A tuple containing the deserialized object, as well as the index into\n" +" the input at which the object ended.\n" +); +static PyObject* +Decoder_raw_decode(Decoder *self, PyObject *const *args, Py_ssize_t nargs) +{ + if (!check_positional_nargs(nargs, 1, 1)) { + return NULL; + } + + DecoderState state = { + .type = self->type, + .strict = self->strict, + .dec_hook = self->dec_hook, + .ext_hook = self->ext_hook + }; + + Py_buffer buffer; + buffer.buf = NULL; + if (PyObject_GetBuffer(args[0], &buffer, PyBUF_CONTIG_RO) >= 0) { + state.buffer_obj = args[0]; + state.input_start = buffer.buf; + state.input_pos = buffer.buf; + state.input_end = state.input_pos + buffer.len; + + PyObject *res = mpack_decode(&state, state.type, NULL, false); + + if (res != NULL) { + PyObject *tup = Py_BuildValue( + "(On)", res, + (Py_ssize_t)(state.input_pos - state.input_start) + ); + Py_CLEAR(res); + res = tup; + } + + PyBuffer_Release(&buffer); + return res; + } + return NULL; +} + static struct PyMethodDef Decoder_methods[] = { { "decode", (PyCFunction) Decoder_decode, METH_FASTCALL, Decoder_decode__doc__, }, + { + "raw_decode", (PyCFunction) Decoder_raw_decode, METH_FASTCALL, + Decoder_raw_decode__doc__, + }, {"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS}, {NULL, NULL} /* sentinel */ }; @@ -19174,6 +19234,72 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na return NULL; } +PyDoc_STRVAR(JSONDecoder_raw_decode__doc__, +"raw_decode(self, buf)\n" +"--\n" +"\n" +"Deserialize an object from JSON, allowing trailing data.\n" +"\n" +"Parameters\n" +"----------\n" +"buf : bytes-like or str\n" +" The message to decode.\n" +"\n" +"Returns\n" +"-------\n" +"obj_and_index : 2-tuple of Any and int\n" +" A tuple containing the deserialized object, as well as the index into\n" +" the input at which the object ended.\n" +); +static PyObject* +JSONDecoder_raw_decode( + JSONDecoder *self, + PyObject *const *args, + Py_ssize_t nargs +) { + if (!check_positional_nargs(nargs, 1, 1)) { + return NULL; + } + + JSONDecoderState state = { + .type = self->type, + .strict = self->strict, + .dec_hook = self->dec_hook, + .float_hook = self->float_hook, + .scratch = NULL, + .scratch_capacity = 0, + .scratch_len = 0 + }; + + Py_buffer buffer; + buffer.buf = NULL; + if (ms_get_buffer(args[0], &buffer) >= 0) { + + state.buffer_obj = args[0]; + state.input_start = buffer.buf; + state.input_pos = buffer.buf; + state.input_end = state.input_pos + buffer.len; + + PyObject *res = json_decode(&state, state.type, NULL); + + if (res != NULL) { + PyObject *tup = Py_BuildValue( + "(On)", res, + (Py_ssize_t)(state.input_pos - state.input_start) + ); + Py_CLEAR(res); + res = tup; + } + + ms_release_buffer(&buffer); + + PyMem_Free(state.scratch); + return res; + } + + return NULL; +} + static struct PyMethodDef JSONDecoder_methods[] = { { "decode", (PyCFunction) JSONDecoder_decode, METH_FASTCALL, @@ -19183,6 +19309,10 @@ static struct PyMethodDef JSONDecoder_methods[] = { "decode_lines", (PyCFunction) JSONDecoder_decode_lines, METH_FASTCALL, JSONDecoder_decode_lines__doc__, }, + { + "raw_decode", (PyCFunction) JSONDecoder_raw_decode, METH_FASTCALL, + JSONDecoder_raw_decode__doc__, + }, {"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS}, {NULL, NULL} /* sentinel */ }; diff --git a/msgspec/json.pyi b/msgspec/json.pyi index 75365d60..19b84192 100644 --- a/msgspec/json.pyi +++ b/msgspec/json.pyi @@ -77,6 +77,7 @@ class Decoder(Generic[T]): ) -> None: ... def decode(self, buf: Union[Buffer, str], /) -> T: ... def decode_lines(self, buf: Union[Buffer, str], /) -> list[T]: ... + def raw_decode(self, buf: Union[Buffer, str], /) -> Tuple[T, int]: ... @overload def decode( diff --git a/msgspec/msgpack.pyi b/msgspec/msgpack.pyi index 1321571d..b3cc84d8 100644 --- a/msgspec/msgpack.pyi +++ b/msgspec/msgpack.pyi @@ -4,6 +4,7 @@ from typing import ( Generic, Literal, Optional, + Tuple, Type, TypeVar, Union, @@ -58,6 +59,7 @@ class Decoder(Generic[T]): ext_hook: ext_hook_sig = None, ) -> None: ... def decode(self, buf: Buffer, /) -> T: ... + def raw_decode(self, buf: Buffer, /) -> Tuple[T, int]: ... class Encoder: enc_hook: enc_hook_sig diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index 9ff5b5c7..a0143e2b 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -640,6 +640,22 @@ def check_msgpack_Decoder_decode_type_comment() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_msgpack_Decoder_raw_decode_any() -> None: + dec = msgspec.msgpack.Decoder() + b = msgspec.msgpack.encode([1, 2, 3]) + o = dec.raw_decode(b) + + reveal_type(o) # assert "tuple" in typ.lower() and "Any" in typ and "int" in typ + + +def check_msgpack_Decoder_raw_decode_typed() -> None: + dec = msgspec.msgpack.Decoder(int) + b = msgspec.msgpack.encode(1) + o = dec.raw_decode(b) + + reveal_type(o) # assert ("Tuple" in typ or "tuple" in typ) and typ.count("int") == 2 + + def check_msgpack_decode_any() -> None: b = msgspec.msgpack.encode([1, 2, 3]) o = msgspec.msgpack.decode(b) @@ -814,6 +830,19 @@ def check_json_Decoder_decode_lines_typed() -> None: reveal_type(o) # assert "list" in typ.lower() and "int" in typ.lower() +def check_json_Decoder_raw_decode_any() -> None: + dec = msgspec.json.Decoder() + o = dec.raw_decode(b'1') + + reveal_type(o) # assert "tuple" in typ.lower() and "any" in typ.lower() and "int" in typ.lower() + + +def check_json_Decoder_raw_decode_typed() -> None: + dec = msgspec.json.Decoder(int) + o = dec.raw_decode(b'1') + reveal_type(o) # assert "tuple" in typ.lower() and typ.lower().count("int") == 2 + + def check_json_decode_any() -> None: b = msgspec.json.encode([1, 2, 3]) o = msgspec.json.decode(b) diff --git a/tests/test_json.py b/tests/test_json.py index 1d3776e5..e3f63150 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -527,6 +527,30 @@ def test_decode_lines_bad_call(self): with pytest.raises(TypeError): dec.decode(1) + def test_raw_decode(self): + dec = msgspec.json.Decoder() + + obj, index = dec.raw_decode(b"[1, 2, 3]trailing invalid") + assert obj == [1, 2, 3] + assert index == len(b"[1, 2, 3]") + + def test_raw_decode_malformed(self): + dec = msgspec.json.Decoder() + with pytest.raises(msgspec.DecodeError, match="malformed"): + dec.raw_decode(b'{"x": efg') + + def test_raw_decode_bad_call(self): + dec = msgspec.json.Decoder() + + with pytest.raises(TypeError): + dec.raw_decode() + + with pytest.raises(TypeError): + dec.raw_decode("{}", 2) + + with pytest.raises(TypeError): + dec.raw_decode(1) + def test_decoder_init_float_hook(self): dec = msgspec.json.Decoder() assert dec.float_hook is None diff --git a/tests/test_msgpack.py b/tests/test_msgpack.py index 55db40f5..e679927f 100644 --- a/tests/test_msgpack.py +++ b/tests/test_msgpack.py @@ -586,6 +586,28 @@ def test_decoding_large_arrays_as_keys_doesnt_preallocate(self): with pytest.raises(msgspec.DecodeError, match="truncated"): msgspec.msgpack.decode(b) + def test_raw_decode(self): + dec = msgspec.msgpack.Decoder() + + msg = msgspec.msgpack.encode([1, 2, 3]) + obj, index = dec.raw_decode(msg + b"trailing") + assert obj == [1, 2, 3] + assert index == len(msg) + + def test_raw_decode_skip_invalid_submessage_raises(self): + """Ensure errors in submessage skipping are raised""" + + class Test(msgspec.Struct): + x: int + + msg = msgspec.msgpack.encode({"x": 1, "y": ["one", "two", "three"]}) + + # Break the message + msg = msg.replace(b"three", b"tree") + + with pytest.raises(msgspec.DecodeError, match="truncated"): + msgspec.msgpack.Decoder(type=Test).raw_decode(msg) + class TestTypedDecoder: def check_unexpected_type(self, dec_type, val, msg):