Skip to content

Commit 3ae65af

Browse files
Cristian Garcia authored and Flax Authors committed
Add tree-mode support for nnx.recursive_map, nnx.view, and nnx.view_info
Add a `graph` parameter to `nnx.recursive_map`, `nnx.view`, and `nnx.view_info`. When `graph=False`, these functions use JAX's native pytree traversal instead of Flax's graph protocol. Cycles and shared Variable/Ref references are detected and raise errors in tree mode. Added parametrized tests for both graph and tree modes.

PiperOrigin-RevId: 873225517
1 parent 8f8ec30 commit 3ae65af

File tree

5 files changed: +208 -47 lines changed

flax/nnx/graph.py

Lines changed: 80 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2907,35 +2907,37 @@ def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29072907

29082908

29092909
def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
2910-
seen_ids: set[int] = set()
2910+
in_progress: set[int] = set()
2911+
seen_refs: set[int] = set()
29112912
stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)]
29122913
while stack:
29132914
path, current, traversed = stack.pop()
29142915

29152916
if traversed:
2917+
in_progress.discard(id(current))
29162918
yield path, current
29172919
continue
29182920

29192921
if not is_pytree_node(current, check_graph_registry=False):
29202922
if isinstance(current, Variable) or variablelib.is_array_ref(current):
29212923
obj_id = id(current)
2922-
if obj_id in seen_ids:
2924+
if obj_id in seen_refs:
29232925
raise ValueError(
29242926
f'Found duplicate Variable or Ref at path '
29252927
f'"{"/".join(map(str, path))}". '
29262928
'Shared references are not supported with graph=False.'
29272929
)
2928-
seen_ids.add(obj_id)
2930+
seen_refs.add(obj_id)
29292931
yield path, current
29302932
continue
29312933

29322934
obj_id = id(current)
2933-
if obj_id in seen_ids:
2935+
if obj_id in in_progress:
29342936
raise ValueError(
29352937
f'Found cycle at path "{"/".join(map(str, path))}". '
29362938
'Cycles are not supported with graph=False.'
29372939
)
2938-
seen_ids.add(obj_id)
2940+
in_progress.add(obj_id)
29392941

29402942
stack.append((path, current, True))
29412943
children, _ = jax.tree_util.tree_flatten_with_path(
@@ -2946,7 +2948,13 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29462948
stack.append(((*path, key), child, False))
29472949

29482950

2949-
def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /):
2951+
def recursive_map(
2952+
f: tp.Callable[[PathParts, tp.Any], tp.Any],
2953+
node: tp.Any,
2954+
/,
2955+
*,
2956+
graph: bool | None = None,
2957+
):
29502958
"""Recursively applies a function to all nodes and leaves of the given graph node.
29512959
29522960
Example::
@@ -2969,15 +2977,28 @@ def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /):
29692977
Path = .conv Conv
29702978
Path = .lin Linear
29712979
Path = . MyModule
2980+
2981+
Args:
2982+
f: A function that takes a path and a node and returns a new node.
2983+
node: A graph node object.
2984+
graph: If ``True`` (default), uses graph-mode which supports the full
2985+
NNX feature set including shared references. If ``False``, uses
2986+
tree-mode which treats Modules as regular JAX pytrees, avoiding
2987+
the overhead of the graph protocol.
29722988
"""
2973-
node = clone(node, variables=False)
2974-
path_parts: PathParts = ()
2975-
visited: set[int] = set()
2976-
results: dict[int, tp.Any] = {}
2977-
return _recursive_map(f, node, path_parts, visited, results)
2989+
if graph is None:
2990+
graph = config.flax_nnx_graph_mode
2991+
if graph:
2992+
node = clone(node, variables=False)
2993+
path_parts: PathParts = ()
2994+
visited: set[int] = set()
2995+
results: dict[int, tp.Any] = {}
2996+
return _recursive_map_graph(f, node, path_parts, visited, results)
2997+
else:
2998+
return _recursive_map_tree(f, node)
29782999

29793000

2980-
def _recursive_map(
3001+
def _recursive_map_graph(
29813002
f: tp.Callable[[PathParts, tp.Any], tp.Any],
29823003
node: tp.Any,
29833004
path: PathParts,
@@ -3002,7 +3023,7 @@ def _recursive_map(
30023023
visited.add(node_id)
30033024
if node_impl is not None:
30043025
for key, value in node_impl.node_dict(node).items():
3005-
new_value = _recursive_map(f, value, (*path, key), visited, results)
3026+
new_value = _recursive_map_graph(f, value, (*path, key), visited, results)
30063027
if new_value is not value:
30073028
if node_impl.set_key is not None and value is not new_value:
30083029
node_impl.set_key(node, key, new_value)
@@ -3017,6 +3038,52 @@ def _recursive_map(
30173038
return new_node
30183039

30193040

3041+
def _recursive_map_tree(
3042+
f: tp.Callable[[PathParts, tp.Any], tp.Any],
3043+
node: tp.Any,
3044+
) -> tp.Any:
3045+
in_progress: set[int] = set()
3046+
seen_refs: set[int] = set()
3047+
3048+
def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
3049+
if not is_pytree_node(current, check_graph_registry=False):
3050+
if isinstance(current, Variable) or is_array_ref(current):
3051+
obj_id = id(current)
3052+
if obj_id in seen_refs:
3053+
raise ValueError(
3054+
f'Found duplicate Variable or Ref at path '
3055+
f'"{"/".join(map(str, path))}". '
3056+
'Shared references are not supported with graph=False.'
3057+
)
3058+
seen_refs.add(obj_id)
3059+
return f(path, current)
3060+
3061+
obj_id = id(current)
3062+
if obj_id in in_progress:
3063+
raise ValueError(
3064+
f'Found cycle at path "{"/".join(map(str, path))}". '
3065+
'Cycles are not supported with graph=False.'
3066+
)
3067+
in_progress.add(obj_id)
3068+
3069+
children_with_path, treedef = jax.tree_util.tree_flatten_with_path(
3070+
current, is_leaf=lambda x: x is not current
3071+
)
3072+
new_children = []
3073+
for jax_key_path, child in children_with_path:
3074+
key = _key_path_to_key(jax_key_path[0])
3075+
new_child = _recurse((*path, key), child)
3076+
new_children.append(new_child)
3077+
3078+
new_node = treedef.unflatten(new_children)
3079+
result = f(path, new_node)
3080+
3081+
in_progress.discard(obj_id)
3082+
return result
3083+
3084+
return _recurse((), node)
3085+
3086+
30203087
def find_duplicates(node: tp.Any, /, *, only: filterlib.Filter = ...) -> list[list[PathParts]]:
30213088
"""Finds duplicate nodes or node leaves in the given node.
30223089

flax/nnx/module.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -428,7 +428,7 @@ def eval(self, **attributes):
428428
raise_if_not_found=False,
429429
)
430430

431-
def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, **kwargs) -> A:
431+
def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool = True, graph: bool | None = None, **kwargs) -> A:
432432
"""Creates a new node with static attributes updated according to ``**kwargs``.
433433
434434
The new node contains references to jax arrays in the original node. If a
@@ -462,6 +462,10 @@ def view(node: A, /, *, only: filterlib.Filter = ..., raise_if_not_found: bool =
462462
Args:
463463
node: the object to create a copy of.
464464
only: Filters to select the Modules to set the attributes of.
465+
graph: If ``True`` (default), uses graph-mode which supports the full
466+
NNX feature set including shared references. If ``False``, uses
467+
tree-mode which treats Modules as regular JAX pytrees, avoiding
468+
the overhead of the graph protocol.
465469
**kwargs: The attributes to set.
466470
"""
467471
predicate = filterlib.to_predicate(only)
@@ -477,7 +481,7 @@ def _set_mode_fn(path, node):
477481
counts[k] += 1
478482
return node
479483

480-
out = graphlib.recursive_map(_set_mode_fn, node)
484+
out = graphlib.recursive_map(_set_mode_fn, node, graph=graph)
481485

482486
if raise_if_not_found:
483487
set_mode_calls = counts.pop("_set_mode_calls")
@@ -518,7 +522,7 @@ def _parse_docstring_args(doc_str: str) -> dict[str, str]:
518522

519523

520524

521-
def view_info(node: Module, /, *, only: filterlib.Filter = ...) -> str:
525+
def view_info(node: Module, /, *, only: filterlib.Filter = ..., graph: bool | None = None) -> str:
522526
"""Provides information about the ``view`` arguments for a module and all
523527
submodules. If no docstring is provided for a module's `set_view`, this function
524528
puts the `set_view` signature below the function.
@@ -554,6 +558,10 @@ def view_info(node: Module, /, *, only: filterlib.Filter = ...) -> str:
554558
Args:
555559
node: the object to display ``view`` information for.
556560
only: Filters to select the Modules to display information for.
561+
graph: If ``True`` (default), uses graph-mode which supports the full
562+
NNX feature set including shared references. If ``False``, uses
563+
tree-mode which treats Modules as regular JAX pytrees, avoiding
564+
the overhead of the graph protocol.
557565
"""
558566
predicate = filterlib.to_predicate(only)
559567
classes: set[Module] = set()
@@ -563,7 +571,7 @@ def _set_mode_info_fn(path, node):
563571
classes.add(node.__class__)
564572
return node
565573

566-
graphlib.recursive_map(_set_mode_info_fn, node)
574+
graphlib.recursive_map(_set_mode_info_fn, node, graph=graph)
567575

568576
class_list = sorted(list(classes), key=lambda x: x.__qualname__)
569577
out_str = []

flax/nnx/rnglib.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -896,31 +896,31 @@ def _tree_split_rngs(
896896
only: filterlib.Filter = ...,
897897
squeeze: bool = False,
898898
) -> tp.Any:
899-
node = graphlib.clone(node, graph=False)
900899
predicate = filterlib.to_predicate(only)
901-
for path, stream in graphlib.iter_graph(node, graph=False):
900+
901+
def _split_stream(path, node):
902902
if (
903-
isinstance(stream, RngStream)
904-
and predicate((*path, 'key'), stream.key)
905-
and predicate((*path, 'count'), stream.count)
903+
isinstance(node, RngStream)
904+
and predicate((*path, 'key'), node.key)
905+
and predicate((*path, 'count'), node.count)
906906
):
907-
key = stream()
908-
key = random.split(key, splits)
907+
key = random.split(node(), splits)
909908
if squeeze:
910909
key = key[0]
911910
if squeeze:
912-
counts_shape = stream.count.shape
911+
counts_shape = node.count.shape
913912
elif isinstance(splits, int):
914-
counts_shape = (splits, *stream.count.shape)
913+
counts_shape = (splits, *node.count.shape)
915914
else:
916-
counts_shape = (*splits, *stream.count.shape)
915+
counts_shape = (*splits, *node.count.shape)
917916

918-
stream.key = RngKey(key, tag=stream.tag)
919-
stream.count = RngCount(
920-
jnp.zeros(counts_shape, dtype=jnp.uint32), tag=stream.tag
917+
node.key = RngKey(key, tag=node.tag)
918+
node.count = RngCount(
919+
jnp.zeros(counts_shape, dtype=jnp.uint32), tag=node.tag
921920
)
921+
return node
922922

923-
return node
923+
return graphlib.recursive_map(_split_stream, node, graph=False)
924924

925925
@tp.overload
926926
def fork_rngs(

tests/nnx/graph_utils_test.py

Lines changed: 83 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1209,7 +1209,8 @@ def swap(path, node):
12091209
self.assertEqual(bar2[1].d, -20)
12101210
self.assertEqual(n, 2)
12111211

1212-
def test_recursive_map_with_list(self):
1212+
@parameterized.parameters(True, False)
1213+
def test_recursive_map_with_list(self, graph):
12131214
rngs = nnx.Rngs(0)
12141215
model = nnx.Sequential(nnx.Linear(2, 3, rngs=rngs), nnx.relu, nnx.Linear(3, 4, rngs=rngs))
12151216

@@ -1218,7 +1219,7 @@ def add_rank2_lora(_, node):
12181219
return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=rngs)
12191220
return node
12201221

1221-
self.assertEqual(len(nnx.recursive_map(add_rank2_lora, model).layers), 3)
1222+
self.assertEqual(len(nnx.recursive_map(add_rank2_lora, model, graph=graph).layers), 3)
12221223

12231224
def test_graphdef_hash_with_sequential(self):
12241225
rngs = nnx.Rngs(0)
@@ -1503,6 +1504,86 @@ def __init__(self, rngs):
15031504
self.assertIn('Dropout', module_types)
15041505
self.assertLen(modules, 3)
15051506

1507+
def test_recursive_map_tree_mode(self):
1508+
class Foo(nnx.Pytree):
1509+
def __init__(self, d):
1510+
self.d = d
1511+
1512+
foo1 = Foo(10)
1513+
foo2 = Foo(20)
1514+
bar = [foo1, foo2]
1515+
n = 0
1516+
1517+
def inc_d(path, node):
1518+
nonlocal n
1519+
if isinstance(node, Foo):
1520+
n += 1
1521+
node.d += 1
1522+
return node
1523+
1524+
bar2 = nnx.recursive_map(inc_d, bar, graph=False)
1525+
self.assertEqual(bar2[0].d, 11)
1526+
self.assertEqual(bar2[1].d, 21)
1527+
self.assertEqual(n, 2)
1528+
1529+
def test_recursive_map_tree_mode_replace(self):
1530+
class Foo(nnx.Pytree):
1531+
def __init__(self, d):
1532+
self.d = d
1533+
1534+
foo1 = Foo(10)
1535+
foo2 = Foo(20)
1536+
bar = [foo1, foo2]
1537+
n = 0
1538+
1539+
def swap(path, node):
1540+
nonlocal n
1541+
if isinstance(node, Foo):
1542+
n += 1
1543+
node = Foo(-node.d)
1544+
return node
1545+
1546+
bar2 = nnx.recursive_map(swap, bar, graph=False)
1547+
self.assertEqual(bar2[0].d, -10)
1548+
self.assertEqual(bar2[1].d, -20)
1549+
self.assertEqual(n, 2)
1550+
1551+
def test_recursive_map_tree_mode_with_list(self):
1552+
rngs = nnx.Rngs(0)
1553+
model = nnx.Sequential(
1554+
nnx.Linear(2, 3, rngs=rngs), nnx.relu, nnx.Linear(3, 4, rngs=rngs)
1555+
)
1556+
1557+
def add_rank2_lora(_, node):
1558+
if isinstance(node, nnx.Linear):
1559+
return nnx.LoRA(
1560+
node.in_features, 2, node.out_features,
1561+
base_module=node, rngs=rngs,
1562+
)
1563+
return node
1564+
1565+
result = nnx.recursive_map(add_rank2_lora, model, graph=False)
1566+
self.assertLen(result.layers, 3)
1567+
1568+
def test_recursive_map_tree_mode_shared_variable_raises(self):
1569+
v = nnx.Param(jnp.array(1))
1570+
g = [v, v]
1571+
1572+
with self.assertRaisesRegex(
1573+
ValueError, 'Shared references are not supported with graph=False'
1574+
):
1575+
nnx.recursive_map(lambda path, node: node, g, graph=False)
1576+
1577+
def test_recursive_map_tree_mode_cycle_raises(self):
1578+
a = nnx.List([1])
1579+
b = nnx.List([2, a])
1580+
a.append(b)
1581+
1582+
with self.assertRaisesRegex(
1583+
ValueError, 'Cycles are not supported with graph=False'
1584+
):
1585+
nnx.recursive_map(lambda path, node: node, a, graph=False)
1586+
15061587

15071588
if __name__ == '__main__':
15081589
absltest.main()

0 commit comments

Comments
 (0)