Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from .module import view as view
from .module import view_info as view_info
from .module import with_attributes as with_attributes
from .module import iter_children as iter_children, iter_modules as iter_modules
from .module import iter_children as iter_children, iter_modules as iter_modules, iter_module_children as iter_module_children
from .graphlib import merge as merge
from .graphlib import UpdateContext as UpdateContext
from .graphlib import update_context as update_context
Expand Down
65 changes: 65 additions & 0 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def is_node_type(x: type[tp.Any]) -> bool:
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree


def is_node_module(x: tp.Any) -> bool:
return type(x) in GRAPH_REGISTRY


def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
if isinstance(x, Variable):
return None
Expand Down Expand Up @@ -347,6 +351,7 @@ def get_node_impl_for_type(
else:
return None


# use type-aware sorting to support int keys
def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
Expand Down Expand Up @@ -594,6 +599,7 @@ class NodeAttr:

NODE_ATTR = NodeAttr()


@dataclasses.dataclass(frozen=True, slots=True)
class LeafAttr:
pass
Expand Down Expand Up @@ -858,6 +864,7 @@ def flatten( # type: ignore[invalid-annotation]
else:
return graphdef, leaves


@dataclasses.dataclass(frozen=True, slots=True)
class DataElem:
value: tp.Any
Expand All @@ -867,6 +874,7 @@ class DataElem:
class StaticElem:
value: tp.Any


def _graph_flatten(
node: Node,
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
Expand Down Expand Up @@ -3128,6 +3136,63 @@ def iter_children(
yield key, child


def iter_module_children(
node: tp.Any, /, *, graph: bool | None = None,
) -> tp.Iterator[tuple[Key, tp.Any]]:
"""Iterates over all module children of a given node. This function is similar
to :func:`iter_children`, except it only iterates over the module children
only.

Example::

>>> from flax import nnx
...
>>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in nnx.iter_module_children(model):
... print(path, type(module).__name__)
...
>>> for path, module in nnx.iter_children(model):
... print(path, type(module).__name__)
...
bias Param
kernel Param

Args:
node: A graph node object.
graph: If ``True`` (default), uses graph-mode which supports the full
NNX feature set including shared references. If ``False``, uses
tree-mode which treats Modules as regular JAX pytrees, avoiding
the overhead of the graph protocol.
"""
if graph is None:
graph = set_graph_mode.current_value()
if graph:
node_impl = get_node_impl(node)
if node_impl is None:
raise ValueError(
f'Expected a graph node, got {type(node).__name__}. '
'If this is a regular pytree, use graph=False.'
)
node_dict = node_impl.node_dict(node)
for key, value in node_dict.items():
if is_node_module(value):
yield key, value
else:
_check_valid_pytree(node, 'iter_children')
if not is_pytree_node(node, check_graph_registry=False):
raise ValueError(
f'Expected a pytree node, got {type(node).__name__}. '
'If this is a graph node, use graph=True.'
)
children, _ = jax.tree_util.tree_flatten_with_path(
node, is_leaf=lambda x: x is not node
)
for jax_key_path, child in children:
if is_node_module(child):
key = _key_path_to_key(jax_key_path[0])
yield key, child


def recursive_map(
f: tp.Callable[[PathParts, tp.Any], tp.Any],
node: tp.Any,
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,7 @@ def iter_modules(
yield path, value

iter_children = graphlib.iter_children
iter_module_children = graphlib.iter_module_children

P = tp.ParamSpec("P")
R = tp.TypeVar("R")
Expand Down
Loading