Skip to content

Commit 8f8ec30

Browse files
Cristian GarciaFlax Authors
authored andcommitted
Add tree-mode support to iter_graph
Add a graph: bool | None parameter to `iter_graph` and other downstream functions such as `iter_modules`, `split_rngs`, `fork_rngs`, `backup_keys`, and `reseed`. PiperOrigin-RevId: 873211727
1 parent 2f6b04c commit 8f8ec30

File tree

9 files changed

+313
-99
lines changed

9 files changed

+313
-99
lines changed

flax/nnx/bridge/module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _maybe_call_setup(module: Module):
104104
def _bind_module(parent: Module, module: Module) -> Module:
105105
assert parent.scope is not None
106106

107-
for _, value in reversed(list(graph.iter_graph(module))):
107+
for _, value in reversed(list(graph.iter_graph(module, graph=True))):
108108
if isinstance(value, Module):
109109
if module.scope is None:
110110
value.scope = parent.scope.copy() # type: ignore[attribute-error]
@@ -471,7 +471,7 @@ def to_variable(value):
471471
_method = _get_unbound_fn(_method)
472472

473473
# set temporary state
474-
for _, value in graph.iter_graph(module):
474+
for _, value in graph.iter_graph(module, graph=True):
475475
if isinstance(value, Pytree):
476476
value._pytree__state._initializing = _initialize
477477
if isinstance(value, Module):
@@ -486,7 +486,7 @@ def to_variable(value):
486486
finally:
487487
MODULE_CONTEXT.module_stack.pop()
488488
# reset temporary state
489-
for _, value in graph.iter_graph(module):
489+
for _, value in graph.iter_graph(module, graph=True):
490490
if isinstance(value, Pytree):
491491
value._pytree__state._initializing = False
492492
if isinstance(value, Module):

flax/nnx/bridge/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _functional_constructor(*args: tp.Any, **kwargs: tp.Any) -> Functional[M]:
6666

6767

6868
def _set_initializing(module: Module, initializing: bool):
69-
for _, value in graph.iter_graph(module):
69+
for _, value in graph.iter_graph(module, graph=True):
7070
if isinstance(value, Pytree):
7171
value._pytree__state._initializing = initializing
7272

flax/nnx/extract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def check_consistent_aliasing(
5757
node_id_to_variable: dict[int, tp.Any] = {}
5858

5959
# collect all paths and prefixes for each node
60-
for path, value in graph.iter_graph(node):
60+
for path, value in graph.iter_graph(node, graph=True):
6161
if graph.is_graph_node(value) or isinstance(value, graph.Variable):
6262
if isinstance(value, Pytree):
6363
value._check_valid_context(

flax/nnx/graph.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2834,7 +2834,9 @@ def _set_metadata(path: PathParts, variable: V) -> None:
28342834
map_state(_set_metadata, state(node, only))
28352835

28362836

2837-
def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
2837+
def iter_graph(
2838+
node: tp.Any, /, *, graph: bool | None = None,
2839+
) -> tp.Iterator[tuple[PathParts, tp.Any]]:
28382840
"""Iterates over all nested nodes and leaves of the given graph node, including the current node.
28392841
28402842
``iter_graph`` creates a generator that yields path and value pairs, where the
@@ -2866,22 +2868,35 @@ def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
28662868
(0, 'w') Param
28672869
(0,) Linear
28682870
() list
2871+
2872+
Args:
2873+
node: A graph node object.
2874+
graph: If ``True`` (default), uses graph-mode which supports the full
2875+
NNX feature set including shared references. If ``False``, uses
2876+
tree-mode which treats Modules as regular JAX pytrees, avoiding
2877+
the overhead of the graph protocol.
28692878
"""
2879+
if graph is None:
2880+
graph = config.flax_nnx_graph_mode
2881+
if graph:
2882+
return _iter_graph(node)
2883+
else:
2884+
return _iter_tree(node)
2885+
2886+
2887+
def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
28702888
visited: set[int] = set()
28712889
stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)]
28722890
while stack:
2873-
# Yield if the node is either a leaf or has been traversed already.
28742891
path_parts, node, traversed = stack.pop(-1)
28752892
if traversed or not (is_node(node) or isinstance(node, Variable)):
28762893
yield path_parts, node
28772894
continue
28782895

2879-
# Skip if the node has been visited already.
28802896
if id(node) in visited:
28812897
continue
28822898
visited.add(id(node))
28832899

2884-
# Traverse the node.
28852900
if (node_impl := get_node_impl(node)) is None:
28862901
yield path_parts, node
28872902
continue
@@ -2891,6 +2906,46 @@ def iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
28912906
stack.append(((*path_parts, key), child, False))
28922907

28932908

2909+
def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
2910+
seen_ids: set[int] = set()
2911+
stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)]
2912+
while stack:
2913+
path, current, traversed = stack.pop()
2914+
2915+
if traversed:
2916+
yield path, current
2917+
continue
2918+
2919+
if not is_pytree_node(current, check_graph_registry=False):
2920+
if isinstance(current, Variable) or variablelib.is_array_ref(current):
2921+
obj_id = id(current)
2922+
if obj_id in seen_ids:
2923+
raise ValueError(
2924+
f'Found duplicate Variable or Ref at path '
2925+
f'"{"/".join(map(str, path))}". '
2926+
'Shared references are not supported with graph=False.'
2927+
)
2928+
seen_ids.add(obj_id)
2929+
yield path, current
2930+
continue
2931+
2932+
obj_id = id(current)
2933+
if obj_id in seen_ids:
2934+
raise ValueError(
2935+
f'Found cycle at path "{"/".join(map(str, path))}". '
2936+
'Cycles are not supported with graph=False.'
2937+
)
2938+
seen_ids.add(obj_id)
2939+
2940+
stack.append((path, current, True))
2941+
children, _ = jax.tree_util.tree_flatten_with_path(
2942+
current, is_leaf=lambda x: x is not current
2943+
)
2944+
for jax_key_path, child in reversed(children):
2945+
key = _key_path_to_key(jax_key_path[0])
2946+
stack.append(((*path, key), child, False))
2947+
2948+
28942949
def recursive_map(f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, /):
28952950
"""Recursively applies a function to all nodes and leaves of the given graph node.
28962951
@@ -3063,8 +3118,10 @@ class GenericPytree: ...
30633118
from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY
30643119

30653120

3066-
def is_pytree_node(x: tp.Any) -> bool:
3067-
if type(x) in GRAPH_REGISTRY:
3121+
def is_pytree_node(
3122+
x: tp.Any, *, check_graph_registry: bool = True,
3123+
) -> bool:
3124+
if check_graph_registry and type(x) in GRAPH_REGISTRY:
30683125
return False
30693126
elif isinstance(x, Variable):
30703127
return False

flax/nnx/module.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from flax.nnx import (
2424
filterlib,
25-
graph,
25+
graph as graphlib,
2626
)
2727
from flax.nnx import variablelib as variableslib
2828
from flax.nnx.pytreelib import Pytree, PytreeMeta
@@ -477,7 +477,7 @@ def _set_mode_fn(path, node):
477477
counts[k] += 1
478478
return node
479479

480-
out = graph.recursive_map(_set_mode_fn, node)
480+
out = graphlib.recursive_map(_set_mode_fn, node)
481481

482482
if raise_if_not_found:
483483
set_mode_calls = counts.pop("_set_mode_calls")
@@ -563,7 +563,7 @@ def _set_mode_info_fn(path, node):
563563
classes.add(node.__class__)
564564
return node
565565

566-
graph.recursive_map(_set_mode_info_fn, node)
566+
graphlib.recursive_map(_set_mode_info_fn, node)
567567

568568
class_list = sorted(list(classes), key=lambda x: x.__qualname__)
569569
out_str = []
@@ -613,7 +613,9 @@ def first_from(*args: tp.Optional[A], error_msg: str) -> A:
613613
return arg
614614
raise ValueError(error_msg)
615615

616-
def iter_modules(module: Module) -> tp.Iterator[tuple[PathParts, Module]]:
616+
def iter_modules(
617+
module: Module, /, *, graph: bool | None = None,
618+
) -> tp.Iterator[tuple[PathParts, Module]]:
617619
"""Recursively iterates over all nested :class:`Module`'s of the given Module, including
618620
the argument.
619621
@@ -648,8 +650,15 @@ def iter_modules(module: Module) -> tp.Iterator[tuple[PathParts, Module]]:
648650
('submodule', 'linear2') Linear
649651
('submodule',) SubModule
650652
() Block
653+
654+
Args:
655+
module: A :class:`Module` object.
656+
graph: If ``True`` (default), uses graph-mode which supports the full
657+
NNX feature set including shared references. If ``False``, uses
658+
tree-mode which treats Modules as regular JAX pytrees, avoiding
659+
the overhead of the graph protocol.
651660
"""
652-
for path, value in graph.iter_graph(module):
661+
for path, value in graphlib.iter_graph(module, graph=graph):
653662
if isinstance(value, Module):
654663
yield path, value
655664

@@ -687,7 +696,7 @@ def iter_children(module: Module) -> tp.Iterator[tuple[Key, Module]]:
687696
linear Linear
688697
submodule SubModule
689698
"""
690-
node_impl = graph.get_node_impl(module)
699+
node_impl = graphlib.get_node_impl(module)
691700
assert node_impl is not None
692701
node_dict = node_impl.node_dict(module)
693702
for key, value in node_dict.items():

0 commit comments

Comments
 (0)