diff --git a/docs/api/jupyter_client.asynchronous.rst b/docs/api/jupyter_client.asynchronous.rst index 5fee3b81..7377df34 100644 --- a/docs/api/jupyter_client.asynchronous.rst +++ b/docs/api/jupyter_client.asynchronous.rst @@ -7,13 +7,13 @@ Submodules .. automodule:: jupyter_client.asynchronous.client :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: jupyter_client.asynchronous :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/api/jupyter_client.blocking.rst b/docs/api/jupyter_client.blocking.rst index ba194cbf..e74eb840 100644 --- a/docs/api/jupyter_client.blocking.rst +++ b/docs/api/jupyter_client.blocking.rst @@ -7,13 +7,13 @@ Submodules .. automodule:: jupyter_client.blocking.client :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: jupyter_client.blocking :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/api/jupyter_client.ioloop.rst b/docs/api/jupyter_client.ioloop.rst index 2c06a832..f1d9d909 100644 --- a/docs/api/jupyter_client.ioloop.rst +++ b/docs/api/jupyter_client.ioloop.rst @@ -7,19 +7,19 @@ Submodules .. automodule:: jupyter_client.ioloop.manager :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.ioloop.restarter :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: jupyter_client.ioloop :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/api/jupyter_client.provisioning.rst b/docs/api/jupyter_client.provisioning.rst index 5a2de00f..80c431a7 100644 --- a/docs/api/jupyter_client.provisioning.rst +++ b/docs/api/jupyter_client.provisioning.rst @@ -7,25 +7,25 @@ Submodules .. automodule:: jupyter_client.provisioning.factory :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.provisioning.local_provisioner :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.provisioning.provisioner_base :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: jupyter_client.provisioning :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/api/jupyter_client.rst b/docs/api/jupyter_client.rst index 336dcc2d..797ceae1 100644 --- a/docs/api/jupyter_client.rst +++ b/docs/api/jupyter_client.rst @@ -19,139 +19,139 @@ Submodules .. automodule:: jupyter_client.adapter :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.channels :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.channelsabc :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.client :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.clientabc :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.connect :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.consoleapp :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.jsonutil :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.kernelapp :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.kernelspec :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.kernelspecapp :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.launcher :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.localinterfaces :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.manager :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.managerabc :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.multikernelmanager :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.restarter :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.runapp :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.session :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.threaded :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.utils :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.win_interrupt :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: jupyter_client :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/docs/api/jupyter_client.ssh.rst b/docs/api/jupyter_client.ssh.rst index 06a1d9db..9ad3c79b 100644 --- a/docs/api/jupyter_client.ssh.rst +++ b/docs/api/jupyter_client.ssh.rst @@ -7,19 +7,19 @@ Submodules .. automodule:: jupyter_client.ssh.forward :members: - :undoc-members: :show-inheritance: + :undoc-members: .. automodule:: jupyter_client.ssh.tunnel :members: - :undoc-members: :show-inheritance: + :undoc-members: Module contents --------------- .. automodule:: jupyter_client.ssh :members: - :undoc-members: :show-inheritance: + :undoc-members: diff --git a/jupyter_client/session.py b/jupyter_client/session.py index c387cd06..799b08d6 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -12,6 +12,7 @@ # Distributed under the terms of the Modified BSD License. from __future__ import annotations +import functools import hashlib import hmac import json @@ -32,6 +33,7 @@ from traitlets import ( Any, Bool, + Callable, CBytes, CUnicode, Dict, @@ -124,6 +126,30 @@ def json_unpacker(s: str | bytes) -> t.Any: return json.loads(s) +try: + import orjson # type:ignore[import-not-found] +except ModuleNotFoundError: + orjson = None + _default_packer_unpacker = "json", "json" + _default_pack_unpack = (json_packer, json_unpacker) +else: + orjson_packer = functools.partial( + orjson.dumps, default=json_default, option=orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z + ) + orjson_unpacker = orjson.loads + _default_packer_unpacker = "orjson", "orjson" + _default_pack_unpack = (orjson_packer, orjson_unpacker) + +try: + import msgpack # type:ignore[import-not-found] + +except ModuleNotFoundError: + msgpack = None +else: + msgpack_packer = functools.partial(msgpack.packb, default=json_default) + msgpack_unpacker = msgpack.unpackb + + def pickle_packer(o: t.Any) -> bytes: """Pack an object using the pickle module.""" return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) @@ -131,8 +157,6 @@ def pickle_packer(o: t.Any) -> bytes: pickle_unpacker = pickle.loads -default_packer = json_packer -default_unpacker = json_unpacker DELIM = b"" # singleton dummy tracker, which will always report as done @@ -315,7 +339,7 @@ class Session(Configurable): debug : bool whether to trigger extra debugging statements - packer/unpacker : str : 'json', 'pickle' or import_string + packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. @@ -350,48 +374,40 @@ class Session(Configurable): """, ) + # serialization traits: packer = DottedObjectName( - "json", + _default_packer_unpacker[0], config=True, help="""The name of the packer for serializing messages. Should be one of 'json', 'pickle', or an import name for a custom callable serializer.""", ) - - @observe("packer") - def _packer_changed(self, change: t.Any) -> None: - new = change["new"] - if new.lower() == "json": - self.pack = json_packer - self.unpack = json_unpacker - self.unpacker = new - elif new.lower() == "pickle": - self.pack = pickle_packer - self.unpack = pickle_unpacker - self.unpacker = new - else: - self.pack = import_item(str(new)) - unpacker = DottedObjectName( - "json", + _default_packer_unpacker[1], config=True, help="""The name of the unpacker for unserializing messages. Only used with custom functions for `packer`.""", ) - - @observe("unpacker") - def _unpacker_changed(self, change: t.Any) -> None: - new = change["new"] - if new.lower() == "json": - self.pack = json_packer - self.unpack = json_unpacker - self.packer = new - elif new.lower() == "pickle": - self.pack = pickle_packer - self.unpack = pickle_unpacker - self.packer = new + pack = Callable(_default_pack_unpack[0]) # the actual packer function + unpack = Callable(_default_pack_unpack[1]) # the actual unpacker function + + @observe("packer", "unpacker") + def _packer_unpacker_changed(self, change: t.Any) -> None: + new = change["new"].lower() + if new == "orjson" and orjson: + self.pack, self.unpack = orjson_packer, orjson_unpacker + elif new == "json" or new == "orjson": + self.pack, self.unpack = json_packer, json_unpacker + elif new == "pickle": + self.pack, self.unpack = pickle_packer, pickle_unpacker + elif new == "msgpack" and msgpack: + self.pack, self.unpack = msgpack_packer, msgpack_unpacker else: - self.unpack = import_item(str(new)) + obj = import_item(str(change["new"])) + name = "pack" if change["name"] == "packer" else "unpack" + self.set_trait(name, obj) + return + self.packer = self.unpacker = change["new"] session = CUnicode("", config=True, help="""The UUID identifying this session.""") @@ -416,8 +432,7 @@ def _session_changed(self, change: t.Any) -> None: metadata = Dict( {}, config=True, - help="Metadata dictionary, which serves as the default top-level metadata dict for each " - "message.", + help="Metadata dictionary, which serves as the default top-level metadata dict for each message.", ) # if 0, no adapting to do. @@ -486,25 +501,6 @@ def _keyfile_changed(self, change: t.Any) -> None: # for protecting against sends from forks pid = Integer() - # serialization traits: - - pack = Any(default_packer) # the actual packer function - - @observe("pack") - def _pack_changed(self, change: t.Any) -> None: - new = change["new"] - if not callable(new): - raise TypeError("packer must be callable, not %s" % type(new)) - - unpack = Any(default_unpacker) # the actual packer function - - @observe("unpack") - def _unpack_changed(self, change: t.Any) -> None: - # unpacker is not checked - it is assumed to be - new = change["new"] - if not callable(new): - raise TypeError("unpacker must be callable, not %s" % type(new)) - # thresholds: copy_threshold = Integer( 2**16, @@ -514,8 +510,7 @@ def _unpack_changed(self, change: t.Any) -> None: buffer_threshold = Integer( MAX_BYTES, config=True, - help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid " - "pickling.", + help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.", ) item_threshold = Integer( MAX_ITEMS, @@ -533,7 +528,7 @@ def __init__(self, **kwargs: t.Any) -> None: debug : bool whether to trigger extra debugging statements - packer/unpacker : str : 'json', 'pickle' or import_string + packer/unpacker : str : 'orjson', 'json', 'pickle', 'msgpack' or import_string importstrings for methods to serialize message parts. If just 'json' or 'pickle', predefined JSON and pickle packers will be used. Otherwise, the entire importstring must be used. @@ -625,10 +620,7 @@ def _check_packers(self) -> None: unpacked = unpack(packed) assert unpacked == msg_list except Exception as e: - msg = ( - f"unpacker '{self.unpacker}' could not handle output from packer" - f" '{self.packer}': {e}" - ) + msg = f"unpacker '{self.unpacker}' could not handle output from packer '{self.packer}': {e}" raise ValueError(msg) from e # check datetime support diff --git a/pyproject.toml b/pyproject.toml index 27503497..36f3d281 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ requires-python = ">=3.9" dependencies = [ "importlib_metadata>=4.8.3;python_version<\"3.10\"", "jupyter_core>=4.12,!=5.0.*", + "orjson>=3.10.18; implementation_name == 'cpython'", "python-dateutil>=2.8.2", "pyzmq>=23.0", "tornado>=6.2", diff --git a/tests/test_session.py b/tests/test_session.py index f30f44f4..88117d05 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -8,12 +8,14 @@ import uuid import warnings from datetime import datetime +from pickle import PicklingError from unittest import mock import pytest import zmq from dateutil.tz import tzlocal from tornado import ioloop +from traitlets import TraitError from zmq.eventloop.zmqstream import ZMQStream from jupyter_client import jsonutil @@ -40,6 +42,16 @@ def session(): return ss.Session() +serializers = [ + ("json", ss.json_packer, ss.json_unpacker), + ("pickle", ss.pickle_packer, ss.pickle_unpacker), +] +if ss.orjson: + serializers.append(("orjson", ss.orjson_packer, ss.orjson_unpacker)) +if ss.msgpack: + serializers.append(("msgpack", ss.msgpack_packer, ss.msgpack_unpacker)) + + @pytest.mark.usefixtures("no_copy_threshold") class TestSession: def assertEqual(self, a, b): @@ -63,7 +75,11 @@ def test_msg(self, session): self.assertEqual(msg["header"]["msg_type"], "execute") self.assertEqual(msg["msg_type"], "execute") - def test_serialize(self, session): + @pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) + def test_serialize(self, session, packer, pack, unpack): + session.packer = packer + assert session.pack is pack + assert session.unpack is unpack msg = session.msg("execute", content=dict(a=10, b=1.1)) msg_list = session.serialize(msg, ident=b"foo") ident, msg_list = session.feed_identities(msg_list) @@ -233,16 +249,16 @@ async def test_send(self, session): def test_args(self, session): """initialization arguments for Session""" s = session - self.assertTrue(s.pack is ss.default_packer) - self.assertTrue(s.unpack is ss.default_unpacker) + assert s.pack is ss._default_pack_unpack[0] + assert s.unpack is ss._default_pack_unpack[1] self.assertEqual(s.username, os.environ.get("USER", "username")) s = ss.Session() self.assertEqual(s.username, os.environ.get("USER", "username")) - with pytest.raises(TypeError): + with pytest.raises(TraitError): ss.Session(pack="hi") - with pytest.raises(TypeError): + with pytest.raises(TraitError): ss.Session(unpack="hi") u = str(uuid.uuid4()) s = ss.Session(username="carrot", session=u) @@ -490,11 +506,6 @@ async def test_send_raw(self, session): B.close() ctx.term() - def test_set_packer(self, session): - s = session - s.packer = "json" - s.unpacker = "json" - def test_clone(self, session): s = session s._add_digest("initial") @@ -514,14 +525,45 @@ def test_squash_unicode(): assert ss.squash_unicode("hi") == b"hi" -def test_json_packer(): - ss.json_packer(dict(a=1)) - with pytest.raises(ValueError): - ss.json_packer(dict(a=ss.Session())) - ss.json_packer(dict(a=datetime(2021, 4, 1, 12, tzinfo=tzlocal()))) +@pytest.mark.parametrize( + ["description", "data"], + [ + ("dict", [{"a": 1}, [{"a": 1}]]), + ("infinite", [math.inf, ["inf", None]]), + ("datetime", [datetime(2021, 4, 1, 12, tzinfo=tzlocal()), []]), + ], +) +@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) +def test_serialize_objects(packer, pack, unpack, description, data): + data_in, data_out_options = data with warnings.catch_warnings(): warnings.simplefilter("ignore") - ss.json_packer(dict(a=math.inf)) + value = pack(data_in) + unpacked = unpack(value) + if (description == "infinite") and (packer in ["pickle", "msgpack"]): + assert math.isinf(unpacked) + elif description == "datetime": + assert data_in == jsonutil.parse_date(unpacked) + else: + assert unpacked in data_out_options + + +@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) +def test_cannot_serialize(session, packer, pack, unpack): + data = {"a": session} + with pytest.raises((TypeError, ValueError, PicklingError)): + pack(data) + + +@pytest.mark.parametrize("mode", ["packer", "unpacker"]) +@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers) +def test_pack_unpack(session, packer, pack, unpack, mode): + s: ss.Session = session + s.set_trait(mode, packer) + assert s.pack is pack + assert s.unpack is unpack + mode_reverse = "unpacker" if mode == "packer" else "packer" + assert getattr(s, mode_reverse) == packer def test_message_cls():