diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index ac0fd2391..9af38941e 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -53,7 +53,7 @@ from .module import view as view from .module import view_info as view_info from .module import with_attributes as with_attributes -from .module import iter_children as iter_children, iter_modules as iter_modules +from .module import iter_children as iter_children, iter_modules as iter_modules, iter_module_children as iter_module_children from .graphlib import merge as merge from .graphlib import UpdateContext as UpdateContext from .graphlib import update_context as update_context diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index ed1095dbb..15e0ecc8a 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -319,6 +319,10 @@ def is_node_type(x: type[tp.Any]) -> bool: return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree +def is_node_module(x: tp.Any) -> bool: + return type(x) in GRAPH_REGISTRY + + def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None: if isinstance(x, Variable): return None @@ -347,6 +351,7 @@ def get_node_impl_for_type( else: return None + # use type-aware sorting to support int keys def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: key, _ = item @@ -594,6 +599,7 @@ class NodeAttr: NODE_ATTR = NodeAttr() + @dataclasses.dataclass(frozen=True, slots=True) class LeafAttr: pass @@ -858,6 +864,7 @@ def flatten( # type: ignore[invalid-annotation] else: return graphdef, leaves + @dataclasses.dataclass(frozen=True, slots=True) class DataElem: value: tp.Any @@ -867,6 +874,7 @@ class DataElem: class StaticElem: value: tp.Any + def _graph_flatten( node: Node, node_impl: NodeImpl[Node, Leaf, AuxData] | None, @@ -3128,6 +3136,63 @@ def iter_children( yield key, child +def iter_module_children( + node: tp.Any, /, *, graph: bool | None = None, +) -> tp.Iterator[tuple[Key, tp.Any]]: + """Iterates over all module children of a given node. This function is similar + to :func:`iter_children`, except it only iterates over the module children + only. + + Example:: + + >>> from flax import nnx + ... + >>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0)) + >>> for path, module in nnx.iter_module_children(model): + ... print(path, type(module).__name__) + ... + >>> for path, module in nnx.iter_children(model): + ... print(path, type(module).__name__) + ... + bias Param + kernel Param + + Args: + node: A graph node object. + graph: If ``True`` (default), uses graph-mode which supports the full + NNX feature set including shared references. If ``False``, uses + tree-mode which treats Modules as regular JAX pytrees, avoiding + the overhead of the graph protocol. + """ + if graph is None: + graph = set_graph_mode.current_value() + if graph: + node_impl = get_node_impl(node) + if node_impl is None: + raise ValueError( + f'Expected a graph node, got {type(node).__name__}. ' + 'If this is a regular pytree, use graph=False.' + ) + node_dict = node_impl.node_dict(node) + for key, value in node_dict.items(): + if is_node_module(value): + yield key, value + else: + _check_valid_pytree(node, 'iter_children') + if not is_pytree_node(node, check_graph_registry=False): + raise ValueError( + f'Expected a pytree node, got {type(node).__name__}. ' + 'If this is a graph node, use graph=True.' + ) + children, _ = jax.tree_util.tree_flatten_with_path( + node, is_leaf=lambda x: x is not node + ) + for jax_key_path, child in children: + if is_node_module(child): + key = _key_path_to_key(jax_key_path[0]) + yield key, child + + def recursive_map( f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, diff --git a/flax/nnx/module.py b/flax/nnx/module.py index a0eeb857c..a9af10c10 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -766,6 +766,7 @@ def iter_modules( yield path, value iter_children = graphlib.iter_children +iter_module_children = graphlib.iter_module_children P = tp.ParamSpec("P") R = tp.TypeVar("R")