Skip to content

Commit 2feea41

Browse files
committed
[export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization` and `jax._src.export.register_namedtuple_serialization`. Serialization of PyTree nodes is needed to serialize the `in_tree` and `out_tree` fields of `Exported` functions (not to serialize actual instances of the custom types). When writing this I have looked at how TensorFlow handles namedtuple. It does so transparently, without requiring the user to register a serialization handler for the namedtuple type. But this has the disadvantage that on deserialization a fresh distinct namedtuple type is created for each input and output type of the serialized function. This means that calling the deserialized function will return outputs of different types than the function that was serialized. This can be confusing. The Python pickle module does a bit better: it attempts to look up the namedtuple type as a module attribute in the deserializing code, automatically importing the module whose name was saved during serialization. This is too much magic for my taste, as it can result in strange import errors. Hence I added an explicit step for the user to say how they want the namedtuple to be serialized and deserialized. Since I wanted to also add support for `collections.OrderedDict`, which users are asking for, I added more general support for PyTree custom nodes. Note that this registration mechanism works in conjunction with the PyTree custom node registration mechanism. The burden is on the user to decide how to serialize and deserialize the custom auxdata that the PyTree custom registration mechanism uses. Not all custom types will be serializable, but many commonly used ones, e.g., dataclasses, can now be inputs and outputs of the serialized functions.
1 parent bb271aa commit 2feea41

File tree

7 files changed

+423
-30
lines changed

7 files changed

+423
-30
lines changed

jax/_src/export/_export.py

Lines changed: 194 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
from __future__ import annotations
1919

20+
import collections
2021
from collections.abc import Callable, Sequence
2122
import copy
2223
import dataclasses
2324
import functools
2425
import itertools
26+
import json
2527
import re
26-
from typing import Any, Union, cast
28+
from typing import Any, Protocol, TypeVar, Union, cast
2729
import warnings
2830

2931
from absl import logging
@@ -344,6 +346,189 @@ def deserialize(blob: bytearray) -> Exported:
344346
return deserialize(blob)
345347

346348

349+
T = TypeVar("T")
PyTreeAuxData = Any  # alias for tree_util._AuxData


class _SerializeAuxData(Protocol):
  def __call__(self, aux_data: PyTreeAuxData) -> bytes:
    """Serializes the PyTree node AuxData.

    The AuxData is returned by the `flatten_func` registered by
    `tree_util.register_pytree_node`.
    """


class _DeserializeAuxData(Protocol):
  def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData:
    """Deserializes the PyTree node AuxData.

    The result will be passed to `_BuildFromChildren`.
    """


class _BuildFromChildren(Protocol):
  def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
    """Materializes a T given a deserialized AuxData and children.

    This is similar in scope with the `unflatten_func`.
    """


# Maps a registered node type to `(serialized_name, serialize_auxdata)`.
serialization_registry: dict[type, tuple[str, _SerializeAuxData]] = {}

# Maps a serialized name to `(nodetype, deserialize_auxdata, from_children)`.
deserialization_registry: dict[
    str,
    tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {}
384+
385+
386+
def _is_namedtuple(nodetype: type) -> bool:
  """Returns whether `nodetype` is a namedtuple-like class.

  A class qualifies if it subclasses `tuple` and has a `_fields` attribute
  that is a sequence of strings. Anything that is not a class (for example
  an instance of a namedtuple) yields `False` rather than the `TypeError`
  that `issubclass` would raise.
  """
  return (isinstance(nodetype, type) and
          issubclass(nodetype, tuple) and
          hasattr(nodetype, "_fields") and
          isinstance(nodetype._fields, Sequence) and
          all(isinstance(f, str) for f in nodetype._fields))
391+
392+
def register_pytree_node_serialization(
    nodetype: type[T],
    *,
    serialized_name: str,
    serialize_auxdata: _SerializeAuxData,
    deserialize_auxdata: _DeserializeAuxData,
    from_children: _BuildFromChildren | None = None
) -> type[T]:
  """Registers a custom PyTree node for serialization and deserialization.

  You must use this function before you can serialize and deserialize PyTree
  nodes for the types not supported natively. We serialize PyTree nodes for
  the `in_tree` and `out_tree` fields of `Exported`, which are part of the
  exported function's calling convention.

  This function must be called after calling
  `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`,
  which do not require a call to `register_pytree_node`).

  Args:
    nodetype: the type whose PyTree nodes we want to serialize. It is an
      error to attempt to register multiple serializations for a `nodetype`.
    serialized_name: a string that will be present in the serialization and
      will be used to look up the registration during deserialization. It is
      an error to attempt to register multiple serializations for a
      `serialized_name`.
    serialize_auxdata: serialize the PyTree auxdata (returned by the
      `flatten_func` argument to `jax.tree_util.register_pytree_node`).
    deserialize_auxdata: deserialize the auxdata that was serialized by the
      `serialize_auxdata`.
    from_children: if present, this is a function that takes the result of
      `deserialize_auxdata` along with some children and creates an instance
      of `nodetype`. This is similar to the `unflatten_func` passed to
      `jax.tree_util.register_pytree_node`. If not present, we look up
      and use the `unflatten_func`. This is needed for
      `collections.namedtuple`, which does not have a `register_pytree_node`,
      but it can be useful to override that function. Note that the result of
      `from_children` is only used with `jax.tree_util.tree_structure` to
      construct a proper PyTree node, it is not used to construct the outputs
      of the serialized function.

  Returns:
    the same type passed as `nodetype`, so that this function can
    be used as a class decorator.

  Raises:
    ValueError: if `nodetype` or `serialized_name` is already registered, or
      if `from_children` is omitted and `nodetype` has no PyTree registration.
  """
  if nodetype in serialization_registry:
    raise ValueError(
        f"Duplicate serialization registration for type `{nodetype}`. "
        "Previous registration was with serialized_name "
        f"`{serialization_registry[nodetype][0]}`.")
  if serialized_name in deserialization_registry:
    raise ValueError(
        "Duplicate serialization registration for "
        f"serialized_name `{serialized_name}`. "
        "Previous registration was for type "
        f"`{deserialization_registry[serialized_name][0]}`.")
  if from_children is None:
    # Fall back to the unflatten function of the existing PyTree registration.
    if nodetype not in tree_util._registry:
      raise ValueError(
          "If `from_children` is not present, you must call first "
          f"`jax.tree_util.register_pytree_node` for `{nodetype}`")
    from_children = tree_util._registry[nodetype].from_iter

  # Both registries are updated together so that lookups by type (when
  # serializing) and by name (when deserializing) stay consistent.
  serialization_registry[nodetype] = (
      serialized_name, serialize_auxdata)
  deserialization_registry[serialized_name] = (
      nodetype, deserialize_auxdata, from_children)
  return nodetype
460+
461+
462+
def register_namedtuple_serialization(
    nodetype: type[T],
    *,
    serialized_name: str) -> type[T]:
  """Registers a namedtuple for serialization and deserialization.

  JAX has native PyTree support for `collections.namedtuple`, and does not
  require a call to `jax.tree_util.register_pytree_node`. However, if you
  want to serialize functions that have inputs or outputs of a
  namedtuple type, you must register that type for serialization.

  Args:
    nodetype: the type whose PyTree nodes we want to serialize. It is an
      error to attempt to register multiple serializations for a `nodetype`.
      On deserialization, this type must have the same set of keys that
      were present during serialization.
    serialized_name: a string that will be present in the serialization and
      will be used to look up the registration during deserialization. It is
      an error to attempt to register multiple serializations for
      a `serialized_name`.

  Returns:
    the same type passed as `nodetype`, so that this function can
    be used as a class decorator.

  Raises:
    ValueError: if `nodetype` is not a namedtuple type.
  """
  if not _is_namedtuple(nodetype):
    raise ValueError("Use `jax.export.register_pytree_node_serialization` for "
                     "types other than `collections.namedtuple`.")

  def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes:
    # Store the serialized keys in the serialized auxdata. The auxdata handed
    # in by the caller is ignored; the field names come from the type itself.
    del aux_data
    return json.dumps(nodetype._fields).encode("utf-8")

  def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData:
    return json.loads(serialized_aux_data.decode("utf-8"))

  def from_children(aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
    # Use our own "from_children" because namedtuples do not have a pytree
    # registration.
    ser_keys = cast(Sequence[str], aux_data)
    # Explicit check instead of `assert`: the data comes from a serialized
    # blob, and `zip` below would silently truncate on a length mismatch.
    if len(ser_keys) != len(children):
      raise ValueError(
          f"Cannot deserialize a node of type `{nodetype}`: found "
          f"{len(ser_keys)} serialized field names but "
          f"{len(children)} children.")
    return nodetype(**dict(zip(ser_keys, children)))

  return register_pytree_node_serialization(
      nodetype,
      serialized_name=serialized_name,
      serialize_auxdata=serialize_auxdata,
      deserialize_auxdata=deserialize_auxdata,
      from_children=from_children)
512+
513+
514+
# collections.OrderedDict is registered as a pytree node with auxdata being
# `tuple(x.keys())`.
def _serialize_ordereddict_keys(keys: Any) -> bytes:
  """Serializes the keys of a `collections.OrderedDict` as UTF-8 JSON.

  Args:
    keys: the PyTree auxdata for the node; must be a sequence of strings.

  Returns:
    the JSON encoding of `keys`, as UTF-8 bytes.

  Raises:
    NotImplementedError: if `keys` is not a sequence of strings.
  """
  # A bare `str` is itself a Sequence whose elements are `str`, so it must be
  # excluded explicitly; it is not a collection of keys.
  is_str_sequence = (
      isinstance(keys, Sequence) and
      not isinstance(keys, str) and
      all(isinstance(k, str) for k in keys))
  if not is_str_sequence:
    raise NotImplementedError(
        "Serialization of collections.OrderedDict is supported only when the "
        f"keys are strings. Found keys: {keys}.")
  return json.dumps(keys).encode("utf-8")
523+
524+
525+
# Built-in registration for `collections.OrderedDict`: the auxdata (the keys)
# is stored as UTF-8 encoded JSON. No `from_children` is passed, so the
# function already registered with `tree_util` for this type is used.
register_pytree_node_serialization(
    collections.OrderedDict,
    serialized_name="collections.OrderedDict",
    serialize_auxdata=_serialize_ordereddict_keys,
    deserialize_auxdata=lambda blob: json.loads(blob.decode("utf-8")),
)
530+
531+
347532
def default_export_platform() -> str:
348533
"""Retrieves the default export platform.
349534
@@ -404,9 +589,10 @@ def export_back_compat(
404589
disabled_checks: the safety checks to disable. See docstring
405590
of `DisabledSafetyCheck`.
406591
407-
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
408-
or values with `.shape` and `.dtype` attributes, and returns an
409-
`Exported`.
592+
Returns:
593+
a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
594+
or values with `.shape` and `.dtype` attributes, and returns an
595+
`Exported`.
410596
411597
Usage:
412598
@@ -480,9 +666,10 @@ def export(
480666
disabled_checks: the safety checks to disable. See documentation for
481667
of `jax.export.DisabledSafetyCheck`.
482668
483-
Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
484-
or values with `.shape` and `.dtype` attributes, and returns an
485-
`Exported`.
669+
Returns:
670+
a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
671+
or values with `.shape` and `.dtype` attributes, and returns an
672+
`Exported`.
486673
487674
Usage:
488675

jax/_src/export/serialization.fbs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@ enum PyTreeDefKind: byte {
2828
tuple = 2,
2929
list = 3,
3030
dict = 4,
31+
custom = 5,
3132
}
3233

3334
table PyTreeDef {
3435
kind: PyTreeDefKind;
3536
children: [PyTreeDef];
36-
children_names: [string]; // only for "dict"
37+
children_names: [string]; // only for "kind==dict"
38+
custom_name: string; // only for "kind==custom"
39+
custom_auxdata: [byte]; // only for "kind==custom"
3740
}
3841

3942
enum AbstractValueKind: byte {

jax/_src/export/serialization.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import types
1920
from collections.abc import Callable, Sequence
2021
from functools import partial
2122
from typing import TypeVar
@@ -45,6 +46,8 @@
4546
# even if the change is backwards compatible.
4647
# Version 1, Nov 2023, first version.
4748
# Version 2, Dec 16th, 2023, adds the f0 dtype.
49+
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
50+
# This version is backwards compatible with Version 2.
4851
_SERIALIZATION_VERSION = 2
4952

5053
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
@@ -152,21 +155,21 @@ def _serialize_array(
152155

153156
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
154157
serialization_version = exp.SerializationVersion()
155-
if serialization_version != _SERIALIZATION_VERSION:
158+
if serialization_version not in [2, 3]:
156159
raise NotImplementedError(
157160
f"deserialize unsupported version {serialization_version}"
158161
)
159162

160163
fun_name = exp.FunctionName().decode("utf-8")
161-
_, in_tree = tree_util.tree_flatten(
164+
in_tree = tree_util.tree_structure(
162165
_deserialize_pytreedef_to_pytree(exp.InTree())
163166
)
164167
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
165168
deser_aval = partial(_deserialize_aval, scope=scope)
166169
in_avals = _deserialize_tuple(
167170
exp.InAvalsLength, exp.InAvals, deser_aval
168171
)
169-
_, out_tree = tree_util.tree_flatten(
172+
out_tree = tree_util.tree_structure(
170173
_deserialize_pytreedef_to_pytree(exp.OutTree())
171174
)
172175
out_avals = _deserialize_tuple(
@@ -246,30 +249,51 @@ def _serialize_pytreedef(
246249
children_vector_offset = _serialize_array(
247250
builder, _serialize_pytreedef, children
248251
)
252+
custom_name = None
253+
custom_auxdata = None
254+
node_type = node_data and node_data[0]
249255

250256
if node_data is None: # leaf
251257
kind = ser_flatbuf.PyTreeDefKind.leaf
252-
elif node_data[0] is type(None):
258+
elif node_type is types.NoneType:
253259
kind = ser_flatbuf.PyTreeDefKind.none
254-
elif node_data[0] is tuple:
260+
elif node_type is tuple:
255261
kind = ser_flatbuf.PyTreeDefKind.tuple
256-
elif node_data[0] is list:
262+
elif node_type is list:
257263
kind = ser_flatbuf.PyTreeDefKind.list
258-
elif node_data[0] is dict:
264+
elif node_type is dict:
259265
kind = ser_flatbuf.PyTreeDefKind.dict
260266
assert len(node_data[1]) == len(children)
261267
children_names_vector_offset = _serialize_array(
262268
builder, lambda b, s: b.CreateString(s), node_data[1]
263269
)
270+
elif node_type in _export.serialization_registry:
271+
kind = ser_flatbuf.PyTreeDefKind.custom
272+
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
273+
custom_name = builder.CreateString(serialized_name)
274+
serialized_auxdata = serialize_auxdata(node_data[1])
275+
if not isinstance(serialized_auxdata, (bytes, bytearray)):
276+
raise ValueError(
277+
"The custom serialization function for `node_type` must "
278+
f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
279+
custom_auxdata = builder.CreateByteVector(serialized_auxdata)
264280
else:
265-
raise NotImplementedError(f"serializing PyTreeDef {node_data}")
281+
raise ValueError(
282+
"Cannot serialize PyTreeDef containing an "
283+
f"unregistered type `{node_type}`. "
284+
"Use `export.register_pytree_node_serialization` or "
285+
"`export.register_namedtuple_serialization`.")
266286

267287
ser_flatbuf.PyTreeDefStart(builder)
268288
ser_flatbuf.PyTreeDefAddKind(builder, kind)
269289
if children_vector_offset:
270290
ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
271291
if children_names_vector_offset:
272292
ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
293+
if custom_name is not None:
294+
ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
295+
if custom_auxdata is not None:
296+
ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
273297
return ser_flatbuf.PyTreeDefEnd(builder)
274298

275299

@@ -294,6 +318,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
294318
assert p.ChildrenNamesLength() == nr_children
295319
keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
296320
return dict(zip(keys, children))
321+
elif kind == ser_flatbuf.PyTreeDefKind.custom:
322+
serialized_name = p.CustomName().decode("utf-8")
323+
if serialized_name not in _export.deserialization_registry:
324+
raise ValueError(
325+
"Cannot deserialize a PyTreeDef containing an "
326+
f"unregistered type `{serialized_name}`. "
327+
"Use `export.register_pytree_node_serialization` or "
328+
"`export.register_namedtuple_serialization`.")
329+
nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
330+
auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
331+
return from_iter(auxdata, children)
297332
else:
298333
assert False, kind
299334

0 commit comments

Comments
 (0)