Skip to content

Commit 814b32a

Browse files
committed
tree_all: add support for is_leaf
1 parent 0739d52 commit 814b32a

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Remember to align the itemized text with the first line of an item within a list
1818
`from jax.experimental.export import export`, and instead you should use
1919
`from jax.experimental import export`.
2020
The removed functionality has been deprecated since 0.4.24.
21+
* Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.
2122

2223
* Deprecations
2324
* `jax.sharding.XLACompatibleSharding` is deprecated. Please use

jax/_src/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
T = TypeVar("T")
2121

2222

23-
def all(tree: Any) -> bool:
23+
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
2424
"""Alias of :func:`jax.tree_util.tree_all`."""
25-
return tree_util.tree_all(tree)
25+
return tree_util.tree_all(tree, is_leaf=is_leaf)
2626

2727

2828
def flatten(tree: Any,

jax/_src/tree_util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,11 +623,15 @@ def tree_reduce(function: Callable[[T, Any], T],
623623

624624

625625
@export
626-
def tree_all(tree: Any) -> bool:
626+
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
627627
"""Call all() over the leaves of a tree.
628628
629629
Args:
630630
tree: the pytree to evaluate
631+
is_leaf : an optionally specified function that will be called at each
632+
flattening step. It should return a boolean, which indicates whether the
633+
flattening should traverse the current object, or if it should be stopped
634+
immediately, with the whole subtree being treated as a leaf.
631635
632636
Returns:
633637
result: boolean True or False
@@ -643,7 +647,7 @@ def tree_all(tree: Any) -> bool:
643647
- :func:`jax.tree_util.tree_reduce`
644648
- :func:`jax.tree_util.tree_leaves`
645649
"""
646-
return all(tree_leaves(tree))
650+
return all(tree_leaves(tree, is_leaf=is_leaf))
647651

648652

649653
register_pytree_node(

tests/tree_util_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,20 +1153,44 @@ def test_tree_all(self):
11531153
tree_util.tree_all(obj),
11541154
)
11551155

1156+
def test_tree_all_is_leaf(self):
1157+
obj = [True, True, (True, False)]
1158+
is_leaf = lambda x: isinstance(x, tuple)
1159+
self.assertEqual(
1160+
jax.tree.all(obj, is_leaf=is_leaf),
1161+
tree_util.tree_all(obj, is_leaf=is_leaf),
1162+
)
1163+
11561164
def test_tree_flatten(self):
11571165
obj = [1, 2, (3, 4)]
11581166
self.assertEqual(
11591167
jax.tree.flatten(obj),
11601168
tree_util.tree_flatten(obj),
11611169
)
11621170

1171+
def test_tree_flatten_is_leaf(self):
1172+
obj = [1, 2, (3, 4)]
1173+
is_leaf = lambda x: isinstance(x, tuple)
1174+
self.assertEqual(
1175+
jax.tree.flatten(obj, is_leaf=is_leaf),
1176+
tree_util.tree_flatten(obj, is_leaf=is_leaf),
1177+
)
1178+
11631179
def test_tree_leaves(self):
11641180
obj = [1, 2, (3, 4)]
11651181
self.assertEqual(
11661182
jax.tree.leaves(obj),
11671183
tree_util.tree_leaves(obj),
11681184
)
11691185

1186+
def test_tree_leaves_is_leaf(self):
1187+
obj = [1, 2, (3, 4)]
1188+
is_leaf = lambda x: isinstance(x, tuple)
1189+
self.assertEqual(
1190+
jax.tree.leaves(obj, is_leaf=is_leaf),
1191+
tree_util.tree_leaves(obj, is_leaf=is_leaf),
1192+
)
1193+
11701194
def test_tree_map(self):
11711195
func = lambda x: x * 2
11721196
obj = [1, 2, (3, 4)]
@@ -1175,6 +1199,15 @@ def test_tree_map(self):
11751199
tree_util.tree_map(func, obj),
11761200
)
11771201

1202+
def test_tree_map_is_leaf(self):
1203+
func = lambda x: x * 2
1204+
obj = [1, 2, (3, 4)]
1205+
is_leaf = lambda x: isinstance(x, tuple)
1206+
self.assertEqual(
1207+
jax.tree.map(func, obj, is_leaf=is_leaf),
1208+
tree_util.tree_map(func, obj, is_leaf=is_leaf),
1209+
)
1210+
11781211
def test_tree_reduce(self):
11791212
func = lambda a, b: a + b
11801213
obj = [1, 2, (3, 4)]
@@ -1183,13 +1216,30 @@ def test_tree_reduce(self):
11831216
tree_util.tree_reduce(func, obj),
11841217
)
11851218

1219+
def test_tree_reduce_is_leaf(self):
1220+
func = lambda a, b: a + b
1221+
obj = [(1, 2), (3, 4)]
1222+
is_leaf = lambda x: isinstance(x, tuple)
1223+
self.assertEqual(
1224+
jax.tree.reduce(func, obj, is_leaf=is_leaf),
1225+
tree_util.tree_reduce(func, obj, is_leaf=is_leaf),
1226+
)
1227+
11861228
def test_tree_structure(self):
11871229
obj = [1, 2, (3, 4)]
11881230
self.assertEqual(
11891231
jax.tree.structure(obj),
11901232
tree_util.tree_structure(obj),
11911233
)
11921234

1235+
def test_tree_structure_is_leaf(self):
1236+
obj = [1, 2, (3, 4)]
1237+
is_leaf = lambda x: isinstance(x, tuple)
1238+
self.assertEqual(
1239+
jax.tree.structure(obj, is_leaf=is_leaf),
1240+
tree_util.tree_structure(obj, is_leaf=is_leaf),
1241+
)
1242+
11931243
def test_tree_transpose(self):
11941244
obj = [(1, 2), (3, 4), (5, 6)]
11951245
outer_treedef = tree_util.tree_structure(['*', '*', '*'])

0 commit comments

Comments
 (0)