Skip to content

Commit f4b84e1

Browse files
Merge pull request jax-ml#24342 from gnecula:export_custom_types
PiperOrigin-RevId: 688093192
2 parents 0d7ef9c + 2feea41 commit f4b84e1

File tree

7 files changed

+423
-30
lines changed

7 files changed

+423
-30
lines changed

jax/_src/export/_export.py

Lines changed: 194 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
from __future__ import annotations
1919

20+
import collections
2021
from collections.abc import Callable, Sequence
2122
import copy
2223
import dataclasses
2324
import functools
2425
import itertools
26+
import json
2527
import re
26-
from typing import Any, Union, cast
28+
from typing import Any, Protocol, TypeVar, Union, cast
2729
import warnings
2830

2931
from absl import logging
@@ -344,6 +346,189 @@ def deserialize(blob: bytearray) -> Exported:
344346
return deserialize(blob)
345347

346348

349+
# Type variable for the node type handled by the registration functions
# below; they return the same `type[T]` they receive.
T = TypeVar("T")
# The auxiliary data of a PyTree node.
PyTreeAuxData = Any  # alias for tree_util._AuxData
351+
352+
353+
class _SerializeAuxData(Protocol):
  """Signature of the callables that serialize the AuxData of a PyTree node."""

  def __call__(self, aux_data: PyTreeAuxData) -> bytes:
    """Serializes the PyTree node AuxData.

    The AuxData is the value returned by the `flatten_func` that was
    registered with `tree_util.register_pytree_node`.
    """
    ...
360+
361+
362+
class _DeserializeAuxData(Protocol):
  """Signature of the callables that deserialize the AuxData of a PyTree node."""

  def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData:
    """Deserializes the PyTree node AuxData.

    The value returned here is what gets passed to `_BuildFromChildren`.
    """
    ...
368+
369+
370+
class _BuildFromChildren(Protocol):
  """Signature of the callables that rebuild a PyTree node from its parts."""

  def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any:
    """Materializes a T given a deserialized AuxData and children.

    The role of this callable is similar to that of the `unflatten_func`.
    """
    ...
376+
377+
378+
# Maps a registered node type to the pair
# `(serialized_name, serialize_auxdata)`. Entries are added by
# `register_pytree_node_serialization`.
serialization_registry: dict[type, tuple[str, _SerializeAuxData]] = {}


# Maps a `serialized_name` to the triple
# `(nodetype, deserialize_auxdata, from_children)`. Entries are added by
# `register_pytree_node_serialization`.
deserialization_registry: dict[
    str,
    tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {}
384+
385+
386+
def _is_namedtuple(nodetype: type) -> bool:
  """Returns whether `nodetype` looks like a `collections.namedtuple` type.

  A type qualifies when it subclasses `tuple` and carries a `_fields`
  attribute that is a sequence made up only of strings.
  """
  if not issubclass(nodetype, tuple):
    return False
  field_names = getattr(nodetype, "_fields", None)
  if not isinstance(field_names, Sequence):
    return False
  return all(isinstance(name, str) for name in field_names)
391+
392+
def register_pytree_node_serialization(
    nodetype: type[T],
    *,
    serialized_name: str,
    serialize_auxdata: _SerializeAuxData,
    deserialize_auxdata: _DeserializeAuxData,
    from_children: _BuildFromChildren | None = None
) -> type[T]:
  """Registers a custom PyTree node for serialization and deserialization.

  You must use this function before you can serialize and deserialize PyTree
  nodes for the types not supported natively. We serialize PyTree nodes for
  the `in_tree` and `out_tree` fields of `Exported`, which are part of the
  exported function's calling convention.

  This function must be called after calling
  `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`,
  which do not require a call to `register_pytree_node`).

  Args:
    nodetype: the type whose PyTree nodes we want to serialize. It is an
      error to attempt to register multiple serializations for a `nodetype`.
    serialized_name: a string that will be present in the serialization and
      will be used to look up the registration during deserialization. It is
      an error to attempt to register multiple serializations for a
      `serialized_name`.
    serialize_auxdata: serialize the PyTree auxdata (returned by the
      `flatten_func` argument to `jax.tree_util.register_pytree_node`).
    deserialize_auxdata: deserialize the auxdata that was serialized by the
      `serialize_auxdata`.
    from_children: if present, this is a function that takes the result of
      `deserialize_auxdata` along with some children and creates an instance
      of `nodetype`. This is similar to the `unflatten_func` passed to
      `jax.tree_util.register_pytree_node`. If not present, we look up
      and use the `unflatten_func`. This is needed for
      `collections.namedtuple`, which does not have a `register_pytree_node`,
      but it can be useful to override that function. Note that the result of
      `from_children` is only used with `jax.tree_util.tree_structure` to
      construct a proper PyTree node, it is not used to construct the outputs
      of the serialized function.

  Returns:
    the same type passed as `nodetype`, so that this function can
    be used as a class decorator.

  Raises:
    ValueError: if `nodetype` or `serialized_name` already has a
      registration, or if `from_children` is not given and `nodetype` has no
      entry in the `jax.tree_util` PyTree registry.
  """
  # All validation happens before either registry is written, so a failed
  # registration leaves both registries untouched.
  if nodetype in serialization_registry:
    raise ValueError(
        f"Duplicate serialization registration for type `{nodetype}`. "
        "Previous registration was with serialized_name "
        f"`{serialization_registry[nodetype][0]}`.")
  if serialized_name in deserialization_registry:
    raise ValueError(
        "Duplicate serialization registration for "
        f"serialized_name `{serialized_name}`. "
        "Previous registration was for type "
        f"`{deserialization_registry[serialized_name][0]}`.")
  if from_children is None:
    if nodetype not in tree_util._registry:
      # Note the trailing space after "first": without it the message ran
      # the word into the backquoted function name.
      raise ValueError(
          "If `from_children` is not present, you must call first "
          f"`jax.tree_util.register_pytree_node` for `{nodetype}`")
    # Fall back to the function stored in the PyTree registry entry.
    from_children = tree_util._registry[nodetype].from_iter

  serialization_registry[nodetype] = (
      serialized_name, serialize_auxdata)
  deserialization_registry[serialized_name] = (
      nodetype, deserialize_auxdata, from_children)
  return nodetype
460+
461+
462+
def register_namedtuple_serialization(
    nodetype: type[T],
    *,
    serialized_name: str) -> type[T]:
  """Registers a namedtuple for serialization and deserialization.

  JAX has native PyTree support for `collections.namedtuple`, and does not
  require a call to `jax.tree_util.register_pytree_node`. However, if you
  want to serialize functions that have inputs or outputs of a
  namedtuple type, you must register that type for serialization.

  Args:
    nodetype: the type whose PyTree nodes we want to serialize. It is an
      error to attempt to register multiple serializations for a `nodetype`.
      On deserialization, this type must have the same set of keys that
      were present during serialization.
    serialized_name: a string that will be present in the serialization and
      will be used to look up the registration during deserialization. It is
      an error to attempt to register multiple serializations for
      a `serialized_name`.

  Returns:
    the same type passed as `nodetype`, so that this function can
    be used as a class decorator.
  """
  if not _is_namedtuple(nodetype):
    raise ValueError("Use `jax.export.register_pytree_node_serialization` for "
                     "types other than `collections.namedtuple`.")

  def _fields_to_bytes(aux_data: PyTreeAuxData) -> bytes:
    # The incoming auxdata is ignored; the field names of the type are what
    # gets stored in the serialized auxdata.
    del aux_data
    fields_json = json.dumps(nodetype._fields)
    return fields_json.encode("utf-8")

  def _fields_from_bytes(serialized_aux_data: bytes) -> PyTreeAuxData:
    fields_json = serialized_aux_data.decode("utf-8")
    return json.loads(fields_json)

  def _build_namedtuple(aux_data: PyTreeAuxData,
                        children: Sequence[Any]) -> Any:
    # A dedicated builder is required because namedtuples do not have a
    # pytree registration to look up.
    field_names = cast(Sequence[str], aux_data)
    assert len(field_names) == len(children)
    by_field = {name: child for name, child in zip(field_names, children)}
    return nodetype(**by_field)

  return register_pytree_node_serialization(
      nodetype,
      serialized_name=serialized_name,
      serialize_auxdata=_fields_to_bytes,
      deserialize_auxdata=_fields_from_bytes,
      from_children=_build_namedtuple)
512+
513+
514+
# collections.OrderedDict is registered as a pytree node with auxdata being
515+
# `tuple(x.keys())`.
516+
def _serialize_ordereddict_keys(keys):
  """Serializes the keys of an OrderedDict as UTF-8 encoded JSON.

  Only a sequence whose elements are all strings is accepted; anything else
  raises `NotImplementedError`.
  """
  keys_are_strings = (
      isinstance(keys, Sequence) and all(isinstance(k, str) for k in keys))
  if not keys_are_strings:
    raise NotImplementedError(
        "Serialization of collections.OrderedDict is supported only when the "
        f"keys are strings. Found keys: {keys}.")
  return json.dumps(keys).encode("utf-8")
523+
524+
525+
# Register serialization for collections.OrderedDict. The keys are stored as
# UTF-8 encoded JSON, and `from_children` is not passed, so the function
# found in the `tree_util` PyTree registry is used to rebuild the node.
register_pytree_node_serialization(
    collections.OrderedDict,
    serialized_name="collections.OrderedDict",
    serialize_auxdata=_serialize_ordereddict_keys,
    deserialize_auxdata=lambda b: json.loads(b.decode("utf-8")))
530+
531+
347532
def default_export_platform() -> str:
348533
"""Retrieves the default export platform.
349534
@@ -404,9 +589,10 @@ def export_back_compat(
404589
disabled_checks: the safety checks to disable. See docstring
405590
of `DisabledSafetyCheck`.
406591
407-
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
408-
or values with `.shape` and `.dtype` attributes, and returns an
409-
`Exported`.
592+
Returns:
593+
a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
594+
or values with `.shape` and `.dtype` attributes, and returns an
595+
`Exported`.
410596
411597
Usage:
412598
@@ -480,9 +666,10 @@ def export(
480666
disabled_checks: the safety checks to disable. See documentation for
481667
`jax.export.DisabledSafetyCheck`.
482668
483-
Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
484-
or values with `.shape` and `.dtype` attributes, and returns an
485-
`Exported`.
669+
Returns:
670+
a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
671+
or values with `.shape` and `.dtype` attributes, and returns an
672+
`Exported`.
486673
487674
Usage:
488675

jax/_src/export/serialization.fbs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@ enum PyTreeDefKind: byte {
2828
tuple = 2,
2929
list = 3,
3030
dict = 4,
31+
custom = 5,
3132
}
3233

3334
table PyTreeDef {
3435
kind: PyTreeDefKind;
3536
children: [PyTreeDef];
36-
children_names: [string]; // only for "dict"
37+
children_names: [string]; // only for "kind==dict"
38+
custom_name: string; // only for "kind==custom"
39+
custom_auxdata: [byte]; // only for "kind==custom"
3740
}
3841

3942
enum AbstractValueKind: byte {

jax/_src/export/serialization.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import types
1920
from collections.abc import Callable, Sequence
2021
from functools import partial
2122
from typing import TypeVar
@@ -45,6 +46,8 @@
4546
# even if the change is backwards compatible.
4647
# Version 1, Nov 2023, first version.
4748
# Version 2, Dec 16th, 2023, adds the f0 dtype.
49+
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
50+
# This version is backwards compatible with Version 2.
4851
_SERIALIZATION_VERSION = 2
4952

5053
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
@@ -152,21 +155,21 @@ def _serialize_array(
152155

153156
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
154157
serialization_version = exp.SerializationVersion()
155-
if serialization_version != _SERIALIZATION_VERSION:
158+
if serialization_version not in [2, 3]:
156159
raise NotImplementedError(
157160
f"deserialize unsupported version {serialization_version}"
158161
)
159162

160163
fun_name = exp.FunctionName().decode("utf-8")
161-
_, in_tree = tree_util.tree_flatten(
164+
in_tree = tree_util.tree_structure(
162165
_deserialize_pytreedef_to_pytree(exp.InTree())
163166
)
164167
scope = shape_poly.SymbolicScope(()) # TODO: serialize the constraints
165168
deser_aval = partial(_deserialize_aval, scope=scope)
166169
in_avals = _deserialize_tuple(
167170
exp.InAvalsLength, exp.InAvals, deser_aval
168171
)
169-
_, out_tree = tree_util.tree_flatten(
172+
out_tree = tree_util.tree_structure(
170173
_deserialize_pytreedef_to_pytree(exp.OutTree())
171174
)
172175
out_avals = _deserialize_tuple(
@@ -246,30 +249,51 @@ def _serialize_pytreedef(
246249
children_vector_offset = _serialize_array(
247250
builder, _serialize_pytreedef, children
248251
)
252+
custom_name = None
253+
custom_auxdata = None
254+
node_type = node_data and node_data[0]
249255

250256
if node_data is None: # leaf
251257
kind = ser_flatbuf.PyTreeDefKind.leaf
252-
elif node_data[0] is type(None):
258+
elif node_type is types.NoneType:
253259
kind = ser_flatbuf.PyTreeDefKind.none
254-
elif node_data[0] is tuple:
260+
elif node_type is tuple:
255261
kind = ser_flatbuf.PyTreeDefKind.tuple
256-
elif node_data[0] is list:
262+
elif node_type is list:
257263
kind = ser_flatbuf.PyTreeDefKind.list
258-
elif node_data[0] is dict:
264+
elif node_type is dict:
259265
kind = ser_flatbuf.PyTreeDefKind.dict
260266
assert len(node_data[1]) == len(children)
261267
children_names_vector_offset = _serialize_array(
262268
builder, lambda b, s: b.CreateString(s), node_data[1]
263269
)
270+
elif node_type in _export.serialization_registry:
271+
kind = ser_flatbuf.PyTreeDefKind.custom
272+
serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
273+
custom_name = builder.CreateString(serialized_name)
274+
serialized_auxdata = serialize_auxdata(node_data[1])
275+
if not isinstance(serialized_auxdata, (bytes, bytearray)):
276+
raise ValueError(
277+
"The custom serialization function for `node_type` must "
278+
f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
279+
custom_auxdata = builder.CreateByteVector(serialized_auxdata)
264280
else:
265-
raise NotImplementedError(f"serializing PyTreeDef {node_data}")
281+
raise ValueError(
282+
"Cannot serialize PyTreeDef containing an "
283+
f"unregistered type `{node_type}`. "
284+
"Use `export.register_pytree_node_serialization` or "
285+
"`export.register_namedtuple_serialization`.")
266286

267287
ser_flatbuf.PyTreeDefStart(builder)
268288
ser_flatbuf.PyTreeDefAddKind(builder, kind)
269289
if children_vector_offset:
270290
ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
271291
if children_names_vector_offset:
272292
ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
293+
if custom_name is not None:
294+
ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
295+
if custom_auxdata is not None:
296+
ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
273297
return ser_flatbuf.PyTreeDefEnd(builder)
274298

275299

@@ -294,6 +318,17 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
294318
assert p.ChildrenNamesLength() == nr_children
295319
keys = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
296320
return dict(zip(keys, children))
321+
elif kind == ser_flatbuf.PyTreeDefKind.custom:
322+
serialized_name = p.CustomName().decode("utf-8")
323+
if serialized_name not in _export.deserialization_registry:
324+
raise ValueError(
325+
"Cannot deserialize a PyTreeDef containing an "
326+
f"unregistered type `{serialized_name}`. "
327+
"Use `export.register_pytree_node_serialization` or "
328+
"`export.register_namedtuple_serialization`.")
329+
nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
330+
auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
331+
return from_iter(auxdata, children)
297332
else:
298333
assert False, kind
299334

0 commit comments

Comments
 (0)