Skip to content

Commit 26c40fa

Browse files
IvyZXGoogle-ML-Automation
authored andcommitted
Add jax.tree shortcuts for .*_with_path calls, for convenience of users.
PiperOrigin-RevId: 705645570
1 parent ecc2673 commit 26c40fa

File tree

6 files changed

+156
-47
lines changed

6 files changed

+156
-47
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1212

1313
## jax 0.4.38
1414

15+
* Changes:
16+
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
17+
as shortcuts of the corresponding `tree_util` functions.
18+
1519
* Deprecations
1620
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
1721
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,

docs/jax.tree.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@ List of Functions
1313

1414
all
1515
flatten
16+
flatten_with_path
1617
leaves
18+
leaves_with_path
1719
map
20+
map_with_path
1821
reduce
1922
structure
2023
transpose

jax/_src/tree.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,97 @@ def unflatten(treedef: tree_util.PyTreeDef,
284284
- :func:`jax.tree.structure`
285285
"""
286286
return tree_util.tree_unflatten(treedef, leaves)
287+
288+
289+
def flatten_with_path(
290+
tree: Any, is_leaf: Callable[[Any], bool] | None = None
291+
) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]:
292+
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
293+
294+
Args:
295+
tree: a pytree to flatten. If it contains a custom type, it is recommended
296+
to be registered with ``register_pytree_with_keys``.
297+
298+
Returns:
299+
A pair which the first element is a list of key-leaf pairs, each of
300+
which contains a leaf and its key path. The second element is a treedef
301+
representing the structure of the flattened tree.
302+
303+
Examples:
304+
>>> import jax
305+
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
306+
>>> path_vals
307+
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
308+
>>> treedef
309+
PyTreeDef([*, {'x': *}])
310+
311+
See Also:
312+
- :func:`jax.tree.flatten`
313+
- :func:`jax.tree.map_with_path`
314+
- :func:`jax.tree_util.register_pytree_with_keys`
315+
"""
316+
return tree_util.tree_flatten_with_path(tree, is_leaf)
317+
318+
319+
def leaves_with_path(
320+
tree: Any, is_leaf: Callable[[Any], bool] | None = None
321+
) -> list[tuple[tree_util.KeyPath, Any]]:
322+
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
323+
324+
Args:
325+
tree: a pytree. If it contains a custom type, it is recommended to be
326+
registered with ``register_pytree_with_keys``.
327+
328+
Returns:
329+
A list of key-leaf pairs, each of which contains a leaf and its key path.
330+
331+
Examples:
332+
>>> import jax
333+
>>> jax.tree.leaves_with_path([1, {'x': 3}])
334+
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
335+
336+
See Also:
337+
- :func:`jax.tree.leaves`
338+
- :func:`jax.tree.flatten_with_path`
339+
- :func:`jax.tree_util.register_pytree_with_keys`
340+
"""
341+
return tree_util.tree_leaves_with_path(tree, is_leaf)
342+
343+
344+
def map_with_path(
345+
f: Callable[..., Any],
346+
tree: Any,
347+
*rest: Any,
348+
is_leaf: Callable[[Any], bool] | None = None,
349+
) -> Any:
350+
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
351+
352+
This is a more powerful alternative of ``tree_map`` that can take the key path
353+
of each leaf as input argument as well.
354+
355+
Args:
356+
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
357+
each corresponding leaves of the pytrees.
358+
tree: a pytree to be mapped over, with each leaf's key path as the first
359+
positional argument and the leaf itself as the second argument to ``f``.
360+
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
361+
or has ``tree`` as a prefix.
362+
363+
Returns:
364+
A new pytree with the same structure as ``tree`` but with the value at each
365+
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
366+
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
367+
the tuple of values at corresponding nodes in ``rest``.
368+
369+
Examples:
370+
>>> import jax
371+
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
372+
[1, 3, 5]
373+
374+
See Also:
375+
- :func:`jax.tree.map`
376+
- :func:`jax.tree.flatten_with_path`
377+
- :func:`jax.tree.leaves_with_path`
378+
- :func:`jax.tree_util.register_pytree_with_keys`
379+
"""
380+
return tree_util.tree_map_with_path(f, tree, *rest, is_leaf=is_leaf)

jax/_src/tree_util.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,35 +1113,15 @@ def register_static(cls: type[H]) -> type[H]:
11131113
def tree_flatten_with_path(
11141114
tree: Any, is_leaf: Callable[[Any], bool] | None = None
11151115
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
1116-
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
1117-
1118-
Args:
1119-
tree: a pytree to flatten. If it contains a custom type, it must be
1120-
registered with ``register_pytree_with_keys``.
1121-
Returns:
1122-
A pair which the first element is a list of key-leaf pairs, each of
1123-
which contains a leaf and its key path. The second element is a treedef
1124-
representing the structure of the flattened tree.
1125-
"""
1116+
"""Alias of :func:`jax.tree.flatten_with_path`."""
11261117
return default_registry.flatten_with_path(tree, is_leaf)
11271118

11281119

11291120
@export
11301121
def tree_leaves_with_path(
11311122
tree: Any, is_leaf: Callable[[Any], bool] | None = None
11321123
) -> list[tuple[KeyPath, Any]]:
1133-
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
1134-
1135-
Args:
1136-
tree: a pytree. If it contains a custom type, it must be registered with
1137-
``register_pytree_with_keys``.
1138-
Returns:
1139-
A list of key-leaf pairs, each of which contains a leaf and its key path.
1140-
1141-
See Also:
1142-
- :func:`jax.tree_util.tree_leaves`
1143-
- :func:`jax.tree_util.tree_flatten_with_path`
1144-
"""
1124+
"""Alias of :func:`jax.tree.leaves_with_path`."""
11451125
return tree_flatten_with_path(tree, is_leaf)[0]
11461126

11471127

@@ -1157,31 +1137,7 @@ def generate_key_paths(
11571137
def tree_map_with_path(f: Callable[..., Any],
11581138
tree: Any, *rest: Any,
11591139
is_leaf: Callable[[Any], bool] | None = None) -> Any:
1160-
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
1161-
1162-
This is a more powerful alternative of ``tree_map`` that can take the key path
1163-
of each leaf as input argument as well.
1164-
1165-
Args:
1166-
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
1167-
each corresponding leaves of the pytrees.
1168-
tree: a pytree to be mapped over, with each leaf's key path as the first
1169-
positional argument and the leaf itself as the second argument to ``f``.
1170-
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
1171-
or has ``tree`` as a prefix.
1172-
1173-
Returns:
1174-
A new pytree with the same structure as ``tree`` but with the value at each
1175-
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
1176-
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
1177-
the tuple of values at corresponding nodes in ``rest``.
1178-
1179-
See Also:
1180-
- :func:`jax.tree_util.tree_map`
1181-
- :func:`jax.tree_util.tree_flatten_with_path`
1182-
- :func:`jax.tree_util.tree_leaves_with_path`
1183-
"""
1184-
1140+
"""Alias of :func:`jax.tree.map_with_path`."""
11851141
keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf)
11861142
keypath_leaves = list(zip(*keypath_leaves))
11871143
all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]

jax/tree.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919

2020
from jax._src.tree import (
2121
all as all,
22+
flatten_with_path as flatten_with_path,
2223
flatten as flatten,
24+
leaves_with_path as leaves_with_path,
2325
leaves as leaves,
26+
map_with_path as map_with_path,
2427
map as map,
2528
reduce as reduce,
2629
structure as structure,

tests/tree_util_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,55 @@ def test_tree_unflatten(self):
14261426
tree_util.tree_unflatten(treedef, leaves)
14271427
)
14281428

1429+
def test_tree_flatten_with_path(self):
1430+
obj = [1, 2, (3, 4)]
1431+
self.assertEqual(
1432+
jax.tree.flatten_with_path(obj),
1433+
tree_util.tree_flatten_with_path(obj),
1434+
)
1435+
1436+
def test_tree_flatten_with_path_is_leaf(self):
1437+
obj = [1, 2, (3, 4)]
1438+
is_leaf = lambda x: isinstance(x, tuple)
1439+
self.assertEqual(
1440+
jax.tree.flatten_with_path(obj, is_leaf=is_leaf),
1441+
tree_util.tree_flatten_with_path(obj, is_leaf=is_leaf),
1442+
)
1443+
1444+
def test_tree_leaves_with_path(self):
1445+
obj = [1, 2, (3, 4)]
1446+
self.assertEqual(
1447+
jax.tree.leaves_with_path(obj),
1448+
tree_util.tree_leaves_with_path(obj),
1449+
)
1450+
1451+
def test_tree_leaves_with_path_is_leaf(self):
1452+
obj = [1, 2, (3, 4)]
1453+
is_leaf = lambda x: isinstance(x, tuple)
1454+
self.assertEqual(
1455+
jax.tree.leaves_with_path(obj, is_leaf=is_leaf),
1456+
tree_util.tree_leaves_with_path(obj, is_leaf=is_leaf),
1457+
)
1458+
1459+
def test_tree_map_with_path(self):
1460+
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
1461+
obj = [1, 2, (3, 4)]
1462+
obj2 = [5, 6, (7, 8)]
1463+
self.assertEqual(
1464+
jax.tree.map_with_path(func, obj, obj2),
1465+
tree_util.tree_map_with_path(func, obj, obj2),
1466+
)
1467+
1468+
def test_tree_map_with_path_is_leaf(self):
1469+
func = lambda kp, x, y: (sum(k.idx for k in kp), x + y)
1470+
obj = [1, 2, (3, 4)]
1471+
obj2 = [5, 6, (7, 8)]
1472+
is_leaf = lambda x: isinstance(x, tuple)
1473+
self.assertEqual(
1474+
jax.tree.map_with_path(func, obj, obj2, is_leaf=is_leaf),
1475+
tree_util.tree_map_with_path(func, obj, obj2, is_leaf=is_leaf),
1476+
)
1477+
14291478

14301479
class RegistrationTest(jtu.JaxTestCase):
14311480

0 commit comments

Comments
 (0)