|
17 | 17 |
|
18 | 18 | from __future__ import annotations |
19 | 19 |
|
| 20 | +import collections |
20 | 21 | from collections.abc import Callable, Sequence |
21 | 22 | import copy |
22 | 23 | import dataclasses |
23 | 24 | import functools |
24 | 25 | import itertools |
| 26 | +import json |
25 | 27 | import re |
26 | | -from typing import Any, Union, cast |
| 28 | +from typing import Any, Protocol, TypeVar, Union, cast |
27 | 29 | import warnings |
28 | 30 |
|
29 | 31 | from absl import logging |
@@ -344,6 +346,189 @@ def deserialize(blob: bytearray) -> Exported: |
344 | 346 | return deserialize(blob) |
345 | 347 |
|
346 | 348 |
|
| 349 | +T = TypeVar("T") |
| 350 | +PyTreeAuxData = Any # alias for tree_util._AuxData |
| 351 | + |
| 352 | + |
| 353 | +class _SerializeAuxData(Protocol): |
| 354 | + def __call__(self, aux_data: PyTreeAuxData) -> bytes: |
| 355 | + """Serializes the PyTree node AuxData. |
| 356 | +
|
| 357 | + The AuxData is returned by the `flatten_func` registered by |
| 358 | + `tree_util.register_pytree_node`). |
| 359 | + """ |
| 360 | + |
| 361 | + |
| 362 | +class _DeserializeAuxData(Protocol): |
| 363 | + def __call__(self, serialized_aux_data: bytes) -> PyTreeAuxData: |
| 364 | + """Deserializes the PyTree node AuxData. |
| 365 | +
|
| 366 | + The result will be passed to `_BuildFromChildren`. |
| 367 | + """ |
| 368 | + |
| 369 | + |
| 370 | +class _BuildFromChildren(Protocol): |
| 371 | + def __call__(self, aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any: |
| 372 | + """Materializes a T given a deserialized AuxData and children. |
| 373 | +
|
| 374 | + This is similar in scope with the `unflatten_func`. |
| 375 | + """ |
| 376 | + |
| 377 | + |
| 378 | +serialization_registry: dict[type, tuple[str, _SerializeAuxData]] = {} |
| 379 | + |
| 380 | + |
| 381 | +deserialization_registry: dict[ |
| 382 | + str, |
| 383 | + tuple[type, _DeserializeAuxData, _BuildFromChildren]] = {} |
| 384 | + |
| 385 | + |
| 386 | +def _is_namedtuple(nodetype: type) -> bool: |
| 387 | + return (issubclass(nodetype, tuple) and |
| 388 | + hasattr(nodetype, "_fields") and |
| 389 | + isinstance(nodetype._fields, Sequence) and |
| 390 | + all(isinstance(f, str) for f in nodetype._fields)) |
| 391 | + |
| 392 | +def register_pytree_node_serialization( |
| 393 | + nodetype: type[T], |
| 394 | + *, |
| 395 | + serialized_name: str, |
| 396 | + serialize_auxdata: _SerializeAuxData, |
| 397 | + deserialize_auxdata: _DeserializeAuxData, |
| 398 | + from_children: _BuildFromChildren | None = None |
| 399 | +) -> type[T]: |
| 400 | + """Registers a custom PyTree node for serialization and deserialization. |
| 401 | +
|
| 402 | + You must use this function before you can serialize and deserialize PyTree |
| 403 | + nodes for the types not supported natively. We serialize PyTree nodes for |
| 404 | + the `in_tree` and `out_tree` fields of `Exported`, which are part of the |
| 405 | + exported function's calling convention. |
| 406 | +
|
| 407 | + This function must be called after calling |
| 408 | + `jax.tree_util.register_pytree_node` (except for `collections.namedtuple`, |
| 409 | + which do not require a call to `register_pytree_node`). |
| 410 | +
|
| 411 | + Args: |
| 412 | + nodetype: the type whose PyTree nodes we want to serialize. It is an |
| 413 | + error to attempt to register multiple serializations for a `nodetype`. |
| 414 | + serialized_name: a string that will be present in the serialization and |
| 415 | + will be used to look up the registration during deserialization. It is an |
| 416 | + error to attempt to register multiple serializations for a |
| 417 | + `serialized_name`. |
| 418 | + serialize_auxdata: serialize the PyTree auxdata (returned by the |
| 419 | + `flatten_func` argument to `jax.tree_util.register_pytree_node`.). |
| 420 | + deserialize_auxdata: deserialize the auxdata that was serialized by the |
| 421 | + `serialize_auxdata`. |
| 422 | + from_children: if present, this is a function that takes that result of |
| 423 | + `deserialize_auxdata` along with some children and creates an instance |
| 424 | + of `nodetype`. This is similar to the `unflatten_func` passed to |
| 425 | + `jax.tree_util.register_pytree_node`. If not present, we look up |
| 426 | + and use the `unflatten_func`. This is needed for `collections.namedtuple`, |
| 427 | + which does not have a `register_pytree_node`, but it can be useful to |
| 428 | + override that function. Note that the result of `from_children` is |
| 429 | + only used with `jax.tree_util.tree_structure` to construct a proper |
| 430 | + PyTree node, it is not used to construct the outputs of the serialized |
| 431 | + function. |
| 432 | +
|
| 433 | + Returns: |
| 434 | + the same type passed as `nodetype`, so that this function can |
| 435 | + be used as a class decorator. |
| 436 | + """ |
| 437 | + if nodetype in serialization_registry: |
| 438 | + raise ValueError( |
| 439 | + f"Duplicate serialization registration for type `{nodetype}`. " |
| 440 | + "Previous registration was with serialized_name " |
| 441 | + f"`{serialization_registry[nodetype][0]}`.") |
| 442 | + if serialized_name in deserialization_registry: |
| 443 | + raise ValueError( |
| 444 | + "Duplicate serialization registration for " |
| 445 | + f"serialized_name `{serialized_name}`. " |
| 446 | + "Previous registration was for type " |
| 447 | + f"`{deserialization_registry[serialized_name][0]}`.") |
| 448 | + if from_children is None: |
| 449 | + if nodetype not in tree_util._registry: |
| 450 | + raise ValueError( |
| 451 | + f"If `from_children` is not present, you must call first" |
| 452 | + f"`jax.tree_util.register_pytree_node` for `{nodetype}`") |
| 453 | + from_children = tree_util._registry[nodetype].from_iter |
| 454 | + |
| 455 | + serialization_registry[nodetype] = ( |
| 456 | + serialized_name, serialize_auxdata) |
| 457 | + deserialization_registry[serialized_name] = ( |
| 458 | + nodetype, deserialize_auxdata, from_children) |
| 459 | + return nodetype |
| 460 | + |
| 461 | + |
| 462 | +def register_namedtuple_serialization( |
| 463 | + nodetype: type[T], |
| 464 | + *, |
| 465 | + serialized_name: str) -> type[T]: |
| 466 | + """Registers a namedtuple for serialization and deserialization. |
| 467 | +
|
| 468 | + JAX has native PyTree support for `collections.namedtuple`, and does not |
| 469 | + require a call to `jax.tree_util.register_pytree_node`. However, if you |
| 470 | + want to serialize functions that have inputs of outputs of a |
| 471 | + namedtuple type, you must register that type for serialization. |
| 472 | +
|
| 473 | + Args: |
| 474 | + nodetype: the type whose PyTree nodes we want to serialize. It is an |
| 475 | + error to attempt to register multiple serializations for a `nodetype`. |
| 476 | + On deserialization, this type must have the same set of keys that |
| 477 | + were present during serialization. |
| 478 | + serialized_name: a string that will be present in the serialization and |
| 479 | + will be used to look up the registration during deserialization. It is an |
| 480 | + error to attempt to register multiple serializations for |
| 481 | + a `serialized_name`. |
| 482 | +
|
| 483 | + Returns: |
| 484 | + the same type passed as `nodetype`, so that this function can |
| 485 | + be used as a class decorator. |
| 486 | +""" |
| 487 | + if not _is_namedtuple(nodetype): |
| 488 | + raise ValueError("Use `jax.export.register_pytree_node_serialization` for " |
| 489 | + "types other than `collections.namedtuple`.") |
| 490 | + |
| 491 | + def serialize_auxdata(aux_data: PyTreeAuxData) -> bytes: |
| 492 | + # Store the serialized keys in the serialized auxdata |
| 493 | + del aux_data |
| 494 | + return json.dumps(nodetype._fields).encode("utf-8") |
| 495 | + |
| 496 | + def deserialize_auxdata(serialized_aux_data: bytes) -> PyTreeAuxData: |
| 497 | + return json.loads(serialized_aux_data.decode("utf-8")) |
| 498 | + |
| 499 | + def from_children(aux_data: PyTreeAuxData, children: Sequence[Any]) -> Any: |
| 500 | + # Use our own "from_children" because namedtuples do not have a pytree |
| 501 | + # registration. |
| 502 | + ser_keys = cast(Sequence[str], aux_data) |
| 503 | + assert len(ser_keys) == len(children) |
| 504 | + return nodetype(** dict(zip(ser_keys, children))) |
| 505 | + |
| 506 | + return register_pytree_node_serialization( |
| 507 | + nodetype, |
| 508 | + serialized_name=serialized_name, |
| 509 | + serialize_auxdata=serialize_auxdata, |
| 510 | + deserialize_auxdata=deserialize_auxdata, |
| 511 | + from_children=from_children) |
| 512 | + |
| 513 | + |
| 514 | +# collections.OrderedDict is registered as a pytree node with auxdata being |
| 515 | +# `tuple(x.keys())`. |
| 516 | +def _serialize_ordereddict_keys(keys): |
| 517 | + if isinstance(keys, Sequence) and all(isinstance(k, str) for k in keys): |
| 518 | + return json.dumps(keys).encode("utf-8") |
| 519 | + else: |
| 520 | + raise NotImplementedError( |
| 521 | + "Serialization of collections.OrderedDict is supported only when the " |
| 522 | + f"keys are strings. Found keys: {keys}.") |
| 523 | + |
| 524 | + |
| 525 | +register_pytree_node_serialization( |
| 526 | + collections.OrderedDict, |
| 527 | + serialized_name="collections.OrderedDict", |
| 528 | + serialize_auxdata=_serialize_ordereddict_keys, |
| 529 | + deserialize_auxdata=lambda b: json.loads(b.decode("utf-8"))) |
| 530 | + |
| 531 | + |
347 | 532 | def default_export_platform() -> str: |
348 | 533 | """Retrieves the default export platform. |
349 | 534 |
|
@@ -404,9 +589,10 @@ def export_back_compat( |
404 | 589 | disabled_checks: the safety checks to disable. See docstring |
405 | 590 | of `DisabledSafetyCheck`. |
406 | 591 |
|
407 | | - Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, |
408 | | - or values with `.shape` and `.dtype` attributes, and returns an |
409 | | - `Exported`. |
| 592 | + Returns: |
| 593 | + a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, |
| 594 | + or values with `.shape` and `.dtype` attributes, and returns an |
| 595 | + `Exported`. |
410 | 596 |
|
411 | 597 | Usage: |
412 | 598 |
|
@@ -480,9 +666,10 @@ def export( |
480 | 666 | disabled_checks: the safety checks to disable. See documentation for |
481 | 667 | of `jax.export.DisabledSafetyCheck`. |
482 | 668 |
|
483 | | - Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, |
484 | | - or values with `.shape` and `.dtype` attributes, and returns an |
485 | | - `Exported`. |
| 669 | + Returns: |
| 670 | + a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, |
| 671 | + or values with `.shape` and `.dtype` attributes, and returns an |
| 672 | + `Exported`. |
486 | 673 |
|
487 | 674 | Usage: |
488 | 675 |
|
|
0 commit comments