Skip to content

Commit 9b331e2

Browse files
committed
Add orjson_packer and orjson_unpacker to speed up serialization of session messages.
1 parent 1f36fbf commit 9b331e2

File tree

4 files changed

+100
-69
lines changed

4 files changed

+100
-69
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
### Enhancements made
1010

1111
- Support psutil for finding network addresses [#1033](https://github.com/jupyter/jupyter_client/pull/1033) ([@juliangilbey](https://github.com/juliangilbey))
12+
- Add new functions `orjson_packer` and `orjson_unpacker`.
1213

1314
### Bugs fixed
1415

jupyter_client/session.py

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from traitlets import (
3333
Any,
3434
Bool,
35+
Callable,
3536
CBytes,
3637
CUnicode,
3738
Dict,
@@ -124,15 +125,35 @@ def json_unpacker(s: str | bytes) -> t.Any:
124125
return json.loads(s)
125126

126127

128+
try:
129+
import orjson # type:ignore[import-not-found]
130+
except ModuleNotFoundError:
131+
orjson = None
132+
_default_packer_unpacker = "json", "json"
133+
_default_pack_unpack = (json_packer, json_unpacker)
134+
else:
135+
136+
def orjson_packer(obj: t.Any) -> bytes:
137+
"""Convert a json object to a bytes using orjson with a fallback to json_packer."""
138+
try:
139+
return orjson.dumps(
140+
obj, default=json_default, option=orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
141+
)
142+
except (TypeError, ValueError):
143+
return json_packer(obj)
144+
145+
orjson_unpacker = orjson.loads
146+
_default_packer_unpacker = "orjson", "orjson"
147+
_default_pack_unpack = (orjson_packer, orjson_unpacker)
148+
149+
127150
def pickle_packer(o: t.Any) -> bytes:
128151
"""Pack an object using the pickle module."""
129152
return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
130153

131154

132155
pickle_unpacker = pickle.loads
133156

134-
default_packer = json_packer
135-
default_unpacker = json_unpacker
136157

137158
DELIM = b"<IDS|MSG>"
138159
# singleton dummy tracker, which will always report as done
@@ -350,48 +371,43 @@ class Session(Configurable):
350371
""",
351372
)
352373

374+
# serialization traits:
353375
packer = DottedObjectName(
354-
"json",
376+
_default_packer_unpacker[0],
355377
config=True,
356378
help="""The name of the packer for serializing messages.
357379
Should be one of 'json', 'pickle', or an import name
358380
for a custom callable serializer.""",
359381
)
360-
361-
@observe("packer")
362-
def _packer_changed(self, change: t.Any) -> None:
363-
new = change["new"]
364-
if new.lower() == "json":
365-
self.pack = json_packer
366-
self.unpack = json_unpacker
367-
self.unpacker = new
368-
elif new.lower() == "pickle":
369-
self.pack = pickle_packer
370-
self.unpack = pickle_unpacker
371-
self.unpacker = new
372-
else:
373-
self.pack = import_item(str(new))
374-
375382
unpacker = DottedObjectName(
376-
"json",
383+
_default_packer_unpacker[1],
377384
config=True,
378385
help="""The name of the unpacker for unserializing messages.
379386
Only used with custom functions for `packer`.""",
380387
)
388+
pack = Callable(_default_pack_unpack[0]) # the actual packer function
389+
unpack = Callable(_default_pack_unpack[1]) # the actual unpacker function
381390

382-
@observe("unpacker")
383-
def _unpacker_changed(self, change: t.Any) -> None:
391+
@observe("packer", "unpacker")
392+
def _packer_unpacker_changed(self, change: t.Any) -> None:
384393
new = change["new"]
385-
if new.lower() == "json":
394+
new_ = new.lower()
395+
if new_ == "orjson" and orjson:
396+
self.pack = orjson_packer
397+
self.unpack = orjson_unpacker
398+
self.packer = self.unpacker = new
399+
elif new_ in ["json", "orjson"]:
386400
self.pack = json_packer
387401
self.unpack = json_unpacker
388-
self.packer = new
389-
elif new.lower() == "pickle":
402+
self.packer = self.unpacker = new
403+
elif new_ == "pickle":
390404
self.pack = pickle_packer
391405
self.unpack = pickle_unpacker
392-
self.packer = new
406+
self.packer = self.unpacker = new
393407
else:
394-
self.unpack = import_item(str(new))
408+
obj = import_item(str(new))
409+
name = "pack" if change["name"] == "packer" else "unpack"
410+
self.set_trait(name, obj)
395411

396412
session = CUnicode("", config=True, help="""The UUID identifying this session.""")
397413

@@ -416,8 +432,7 @@ def _session_changed(self, change: t.Any) -> None:
416432
metadata = Dict(
417433
{},
418434
config=True,
419-
help="Metadata dictionary, which serves as the default top-level metadata dict for each "
420-
"message.",
435+
help="Metadata dictionary, which serves as the default top-level metadata dict for each message.",
421436
)
422437

423438
# if 0, no adapting to do.
@@ -486,25 +501,6 @@ def _keyfile_changed(self, change: t.Any) -> None:
486501
# for protecting against sends from forks
487502
pid = Integer()
488503

489-
# serialization traits:
490-
491-
pack = Any(default_packer) # the actual packer function
492-
493-
@observe("pack")
494-
def _pack_changed(self, change: t.Any) -> None:
495-
new = change["new"]
496-
if not callable(new):
497-
raise TypeError("packer must be callable, not %s" % type(new))
498-
499-
unpack = Any(default_unpacker) # the actual packer function
500-
501-
@observe("unpack")
502-
def _unpack_changed(self, change: t.Any) -> None:
503-
# unpacker is not checked - it is assumed to be
504-
new = change["new"]
505-
if not callable(new):
506-
raise TypeError("unpacker must be callable, not %s" % type(new))
507-
508504
# thresholds:
509505
copy_threshold = Integer(
510506
2**16,
@@ -514,8 +510,7 @@ def _unpack_changed(self, change: t.Any) -> None:
514510
buffer_threshold = Integer(
515511
MAX_BYTES,
516512
config=True,
517-
help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid "
518-
"pickling.",
513+
help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.",
519514
)
520515
item_threshold = Integer(
521516
MAX_ITEMS,
@@ -625,10 +620,7 @@ def _check_packers(self) -> None:
625620
unpacked = unpack(packed)
626621
assert unpacked == msg_list
627622
except Exception as e:
628-
msg = (
629-
f"unpacker '{self.unpacker}' could not handle output from packer"
630-
f" '{self.packer}': {e}"
631-
)
623+
msg = f"unpacker '{self.unpacker}' could not handle output from packer '{self.packer}': {e}"
632624
raise ValueError(msg) from e
633625

634626
# check datetime support

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ requires-python = ">=3.9"
2222
dependencies = [
2323
"importlib_metadata>=4.8.3;python_version<\"3.10\"",
2424
"jupyter_core>=4.12,!=5.0.*",
25+
"orjson>=3.10.18; implementation_name == 'cpython'",
2526
"python-dateutil>=2.8.2",
2627
"pyzmq>=23.0",
2728
"tornado>=6.2",

tests/test_session.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import zmq
1515
from dateutil.tz import tzlocal
1616
from tornado import ioloop
17+
from traitlets import TraitError
1718
from zmq.eventloop.zmqstream import ZMQStream
1819

1920
from jupyter_client import jsonutil
@@ -40,6 +41,14 @@ def session():
4041
return ss.Session()
4142

4243

44+
serializers = [
45+
("json", ss.json_packer, ss.json_unpacker),
46+
("pickle", ss.pickle_packer, ss.pickle_unpacker),
47+
]
48+
if ss.orjson:
49+
serializers.append(("orjson", ss.orjson_packer, ss.orjson_unpacker))
50+
51+
4352
@pytest.mark.usefixtures("no_copy_threshold")
4453
class TestSession:
4554
def assertEqual(self, a, b):
@@ -63,7 +72,11 @@ def test_msg(self, session):
6372
self.assertEqual(msg["header"]["msg_type"], "execute")
6473
self.assertEqual(msg["msg_type"], "execute")
6574

66-
def test_serialize(self, session):
75+
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
76+
def test_serialize(self, session, packer, pack, unpack):
77+
session.packer = packer
78+
assert session.pack is pack
79+
assert session.unpack is unpack
6780
msg = session.msg("execute", content=dict(a=10, b=1.1))
6881
msg_list = session.serialize(msg, ident=b"foo")
6982
ident, msg_list = session.feed_identities(msg_list)
@@ -233,16 +246,16 @@ async def test_send(self, session):
233246
def test_args(self, session):
234247
"""initialization arguments for Session"""
235248
s = session
236-
self.assertTrue(s.pack is ss.default_packer)
237-
self.assertTrue(s.unpack is ss.default_unpacker)
249+
assert s.pack is ss._default_pack_unpack[0]
250+
assert s.unpack is ss._default_pack_unpack[1]
238251
self.assertEqual(s.username, os.environ.get("USER", "username"))
239252

240253
s = ss.Session()
241254
self.assertEqual(s.username, os.environ.get("USER", "username"))
242255

243-
with pytest.raises(TypeError):
256+
with pytest.raises(TraitError):
244257
ss.Session(pack="hi")
245-
with pytest.raises(TypeError):
258+
with pytest.raises(TraitError):
246259
ss.Session(unpack="hi")
247260
u = str(uuid.uuid4())
248261
s = ss.Session(username="carrot", session=u)
@@ -490,11 +503,6 @@ async def test_send_raw(self, session):
490503
B.close()
491504
ctx.term()
492505

493-
def test_set_packer(self, session):
494-
s = session
495-
s.packer = "json"
496-
s.unpacker = "json"
497-
498506
def test_clone(self, session):
499507
s = session
500508
s._add_digest("initial")
@@ -514,14 +522,43 @@ def test_squash_unicode():
514522
assert ss.squash_unicode("hi") == b"hi"
515523

516524

517-
def test_json_packer():
518-
ss.json_packer(dict(a=1))
519-
with pytest.raises(ValueError):
520-
ss.json_packer(dict(a=ss.Session()))
521-
ss.json_packer(dict(a=datetime(2021, 4, 1, 12, tzinfo=tzlocal())))
525+
@pytest.mark.parametrize(
526+
["description", "data"],
527+
[
528+
("dict", [{"a": 1}, [{"a": 1}]]),
529+
("infinite", [math.inf, ["inf", None]]),
530+
("datetime", [datetime(2021, 4, 1, 12, tzinfo=tzlocal()), ["2021-04-01T12:00:00+11:00"]]),
531+
],
532+
)
533+
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
534+
def test_serialize_objects(packer, pack, unpack, description, data):
535+
data_in, data_out_options = data
522536
with warnings.catch_warnings():
523537
warnings.simplefilter("ignore")
524-
ss.json_packer(dict(a=math.inf))
538+
value = pack(data_in)
539+
unpacked = unpack(value)
540+
if (description == "infinite") and (packer == "pickle"):
541+
assert math.isinf(unpacked)
542+
return
543+
assert unpacked in data_out_options
544+
545+
546+
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
547+
def test_cannot_serialize(session, packer, pack, unpack):
548+
data = {"a": session}
549+
with pytest.raises((TypeError, ValueError)):
550+
pack(data)
551+
552+
553+
@pytest.mark.parametrize("mode", ["packer", "unpacker"])
554+
@pytest.mark.parametrize(["packer", "pack", "unpack"], serializers)
555+
def test_packer_unpacker(session, packer, pack, unpack, mode):
556+
s: ss.Session = session
557+
s.set_trait(mode, packer)
558+
assert s.pack is pack
559+
assert s.unpack is unpack
560+
mode_reverse = "unpacker" if mode == "packer" else "packer"
561+
assert getattr(s, mode_reverse) == packer
525562

526563

527564
def test_message_cls():

0 commit comments

Comments
 (0)