Skip to content

Commit a1dfdc1

Browse files
IvyZX authored and Google-ML-Automation committed
C++ tree with path API
* Make the tree_util.tree_flatten_with_path and tree_map_with_path APIs C++-based to speed up pytree flattening.
* Move all the key classes down to the C++ level while keeping the APIs unchanged.
* Known small caveats: the key classes are no longer Python dataclasses, and pattern matching on them might make pytype unhappy.
* defaultdict and OrderedDict are now registered via the keypath API.

PiperOrigin-RevId: 701613257
1 parent db4b3f2 commit a1dfdc1

File tree

3 files changed

+288
-45
lines changed

3 files changed

+288
-45
lines changed

jax/_src/tree_util.py

Lines changed: 82 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,11 @@
2222
from functools import partial
2323
import operator as op
2424
import textwrap
25-
from typing import Any, NamedTuple, TypeVar, Union, overload
25+
from typing import Any, NamedTuple, TypeVar, overload
2626

2727
from jax._src import traceback_util
2828
from jax._src.lib import pytree
29+
from jax._src.lib import xla_extension_version
2930
from jax._src.util import safe_zip, set_module
3031
from jax._src.util import unzip2
3132

@@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any],
209210

210211
_Children = TypeVar("_Children", bound=Iterable[Any])
211212
_AuxData = TypeVar("_AuxData", bound=Hashable)
213+
KeyEntry = TypeVar("KeyEntry", bound=Any)
214+
KeyLeafPair = tuple[KeyEntry, Any]
215+
KeyLeafPairs = Iterable[KeyLeafPair]
216+
KeyPath = tuple[KeyEntry, ...]
212217

213218

214219
@export
215-
def register_pytree_node(nodetype: type[T],
216-
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
217-
unflatten_func: Callable[[_AuxData, _Children], T]) -> None:
220+
def register_pytree_node(
221+
nodetype: type[T],
222+
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
223+
unflatten_func: Callable[[_AuxData, _Children], T],
224+
flatten_with_keys_func: (
225+
Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None
226+
) = None,
227+
) -> None:
218228
"""Extends the set of types that are considered internal nodes in pytrees.
219229
220230
See :ref:`example usage <pytrees>`.
@@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T],
279289
>>> jax.jit(f)(m)
280290
Array([1., 2., 3., 4., 5.], dtype=float32)
281291
"""
282-
default_registry.register_node(nodetype, flatten_func, unflatten_func)
283-
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
284-
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
292+
if xla_extension_version >= 299:
293+
default_registry.register_node( # type: ignore[call-arg]
294+
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
295+
)
296+
none_leaf_registry.register_node( # type: ignore[call-arg]
297+
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
298+
)
299+
dispatch_registry.register_node( # type: ignore[call-arg]
300+
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
301+
)
302+
else:
303+
default_registry.register_node(nodetype, flatten_func, unflatten_func)
304+
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
305+
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
285306
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
286307

287308

@@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool
452473
return all(tree_leaves(tree, is_leaf=is_leaf))
453474

454475

455-
register_pytree_node(
456-
collections.OrderedDict,
457-
lambda x: (tuple(x.values()), tuple(x.keys())),
458-
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))
459-
460-
def _flatten_defaultdict(d):
461-
keys = tuple(sorted(d))
462-
return tuple(d[k] for k in keys), (d.default_factory, keys)
463-
464-
register_pytree_node(
465-
collections.defaultdict,
466-
_flatten_defaultdict,
467-
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))
468-
469-
470476
class _HashableCallableShim:
471477
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
472478

@@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
578584

579585

580586
# flatten_one_level is not exported.
581-
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
587+
def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
582588
"""Flatten the given pytree node by one level.
583589
584590
Args:
585-
pytree: A valid pytree node, either built-in or registered via
591+
tree: A valid pytree node, either built-in or registered via
586592
:func:`register_pytree_node` or related functions.
587593
588594
Returns:
@@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
601607
>>> meta
602608
('a', 'b')
603609
"""
604-
out = default_registry.flatten_one_level(pytree)
610+
out = default_registry.flatten_one_level(tree)
605611
if out is None:
606-
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
612+
raise ValueError(f"can't tree-flatten type: {type(tree)}")
607613
else:
608614
return out
609615

@@ -739,10 +745,12 @@ class FlattenedIndexKey():
739745
def __str__(self):
740746
return f'[<flat index {self.key}>]'
741747

742-
BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]
743748

744-
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
745-
KeyPath = tuple[KeyEntry, ...]
749+
if xla_extension_version >= 299:
750+
SequenceKey = pytree.SequenceKey # type: ignore
751+
DictKey = pytree.DictKey # type: ignore
752+
GetAttrKey = pytree.GetAttrKey # type: ignore
753+
FlattenedIndexKey = pytree.FlattenedIndexKey # type: ignore
746754

747755

748756
@export
@@ -764,6 +772,7 @@ def keystr(keys: KeyPath):
764772
return ''.join(map(str, keys))
765773

766774

775+
# TODO(ivyzheng): remove this after _child_keys() is also moved to C++.
767776
class _RegistryWithKeypathsEntry(NamedTuple):
768777
flatten_with_keys: Callable[..., Any]
769778
unflatten_func: Callable[..., Any]
@@ -780,7 +789,6 @@ def flatten_with_keys(xs):
780789
flatten_with_keys, _registry[ty].from_iter
781790
)
782791

783-
784792
_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}
785793

786794
_register_keypaths(
@@ -803,13 +811,9 @@ def flatten_with_keys(xs):
803811
@export
804812
def register_pytree_with_keys(
805813
nodetype: type[T],
806-
flatten_with_keys: Callable[
807-
[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]
808-
],
814+
flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]],
809815
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
810-
flatten_func: None | (
811-
Callable[[T], tuple[Iterable[Any], _AuxData]]
812-
) = None,
816+
flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None,
813817
):
814818
"""Extends the set of types that are considered internal nodes in pytrees.
815819
@@ -870,7 +874,9 @@ def flatten_func_impl(tree):
870874
return [c for _, c in key_children], treedef
871875
flatten_func = flatten_func_impl
872876

873-
register_pytree_node(nodetype, flatten_func, unflatten_func)
877+
register_pytree_node(
878+
nodetype, flatten_func, unflatten_func, flatten_with_keys
879+
)
874880
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
875881
flatten_with_keys, unflatten_func
876882
)
@@ -1092,6 +1098,40 @@ def flatten_func(x):
10921098
return nodetype
10931099

10941100

1101+
if xla_extension_version >= 299:
1102+
register_pytree_with_keys(
1103+
collections.OrderedDict,
1104+
lambda x: (tuple((DictKey(k), x[k]) for k in x.keys()), tuple(x.keys())),
1105+
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)),
1106+
)
1107+
1108+
def _flatten_defaultdict_with_keys(d):
1109+
keys = tuple(sorted(d))
1110+
return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys)
1111+
1112+
register_pytree_with_keys(
1113+
collections.defaultdict,
1114+
_flatten_defaultdict_with_keys,
1115+
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)),
1116+
)
1117+
else:
1118+
register_pytree_node(
1119+
collections.OrderedDict,
1120+
lambda x: (tuple(x.values()), tuple(x.keys())),
1121+
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)),
1122+
)
1123+
1124+
def _flatten_defaultdict(d):
1125+
keys = tuple(sorted(d))
1126+
return tuple(d[k] for k in keys), (d.default_factory, keys)
1127+
1128+
register_pytree_node(
1129+
collections.defaultdict,
1130+
_flatten_defaultdict,
1131+
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)),
1132+
)
1133+
1134+
10951135
@export
10961136
def register_static(cls: type[H]) -> type[H]:
10971137
"""Registers `cls` as a pytree with no leaves.
@@ -1144,6 +1184,8 @@ def tree_flatten_with_path(
11441184
which contains a leaf and its key path. The second element is a treedef
11451185
representing the structure of the flattened tree.
11461186
"""
1187+
if xla_extension_version >= 299:
1188+
return default_registry.flatten_with_path(tree, is_leaf)
11471189
_, tree_def = tree_flatten(tree, is_leaf)
11481190
return _generate_key_paths(tree, is_leaf), tree_def
11491191

@@ -1164,13 +1206,15 @@ def tree_leaves_with_path(
11641206
- :func:`jax.tree_util.tree_leaves`
11651207
- :func:`jax.tree_util.tree_flatten_with_path`
11661208
"""
1167-
return _generate_key_paths(tree, is_leaf)
1209+
return tree_flatten_with_path(tree, is_leaf)[0]
11681210

11691211

11701212
# generate_key_paths is not exported.
11711213
def generate_key_paths(
11721214
tree: Any, is_leaf: Callable[[Any], bool] | None = None
11731215
) -> list[tuple[KeyPath, Any]]:
1216+
if xla_extension_version >= 299:
1217+
return tree_leaves_with_path(tree, is_leaf)
11741218
return list(_generate_key_paths_((), tree, is_leaf))
11751219
_generate_key_paths = generate_key_paths # alias for backward compat
11761220

tests/package_structure_test.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,53 @@ class PackageStructureTest(jtu.JaxTestCase):
3434
_mod("jax.errors", exclude=["JaxRuntimeError"]),
3535
_mod(
3636
"jax.numpy",
37-
exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating",
38-
"dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo",
39-
"flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim",
40-
"number", "object_", "printoptions", "save", "savez", "set_printoptions",
41-
"shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"]
37+
exclude=[
38+
"array_repr",
39+
"array_str",
40+
"can_cast",
41+
"character",
42+
"complexfloating",
43+
"dtype",
44+
"iinfo",
45+
"index_exp",
46+
"inexact",
47+
"integer",
48+
"iterable",
49+
"finfo",
50+
"flexible",
51+
"floating",
52+
"generic",
53+
"get_printoptions",
54+
"ndarray",
55+
"ndim",
56+
"number",
57+
"object_",
58+
"printoptions",
59+
"save",
60+
"savez",
61+
"set_printoptions",
62+
"shape",
63+
"signedinteger",
64+
"size",
65+
"s_",
66+
"unsignedinteger",
67+
"ComplexWarning",
68+
],
4269
),
4370
_mod("jax.numpy.linalg"),
4471
_mod("jax.nn.initializers"),
4572
_mod(
4673
"jax.tree_util",
47-
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
74+
exclude=[
75+
"PyTreeDef",
76+
"default_registry",
77+
"KeyEntry",
78+
"KeyPath",
79+
"DictKey",
80+
"GetAttrKey",
81+
"SequenceKey",
82+
"FlattenedIndexKey",
83+
],
4884
),
4985
])
5086
def test_exported_names_match_module(self, module_name, include, exclude):

0 commit comments

Comments
 (0)