Skip to content

Commit a609389

Browse files
hyunn9973Flax Authors
authored andcommitted
Add iter_module_children() to iterate over module children only.
PiperOrigin-RevId: 878177716
1 parent 15683fd commit a609389

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

flax/nnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from .module import view as view
5454
from .module import view_info as view_info
5555
from .module import with_attributes as with_attributes
56-
from .module import iter_children as iter_children, iter_modules as iter_modules
56+
from .module import iter_children as iter_children, iter_modules as iter_modules, iter_module_children as iter_module_children
5757
from .graphlib import merge as merge
5858
from .graphlib import UpdateContext as UpdateContext
5959
from .graphlib import update_context as update_context

flax/nnx/graphlib.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,10 @@ def is_node_type(x: type[tp.Any]) -> bool:
318318
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
319319

320320

321+
def is_node_module(x: tp.Any) -> bool:
322+
return type(x) in GRAPH_REGISTRY
323+
324+
321325
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
322326
if isinstance(x, Variable):
323327
return None
@@ -346,6 +350,7 @@ def get_node_impl_for_type(
346350
else:
347351
return None
348352

353+
349354
# use type-aware sorting to support int keys
350355
def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
351356
key, _ = item
@@ -593,6 +598,7 @@ class NodeAttr:
593598

594599
NODE_ATTR = NodeAttr()
595600

601+
596602
@dataclasses.dataclass(frozen=True, slots=True)
597603
class LeafAttr:
598604
pass
@@ -857,6 +863,7 @@ def flatten( # type: ignore[invalid-annotation]
857863
else:
858864
return graphdef, leaves
859865

866+
860867
@dataclasses.dataclass(frozen=True, slots=True)
861868
class DataElem:
862869
value: tp.Any
@@ -866,6 +873,7 @@ class DataElem:
866873
class StaticElem:
867874
value: tp.Any
868875

876+
869877
def _graph_flatten(
870878
node: Node,
871879
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
@@ -3085,6 +3093,63 @@ def iter_children(
30853093
yield key, child
30863094

30873095

3096+
def iter_module_children(
3097+
node: tp.Any, /, *, graph: bool | None = None,
3098+
) -> tp.Iterator[tuple[Key, tp.Any]]:
3099+
"""Iterates over all module children of a given node. This function is similar
3100+
to :func:`iter_children`, except it only iterates over the module children
3101+
only.
3102+
3103+
Example::
3104+
3105+
>>> from flax import nnx
3106+
...
3107+
>>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0))
3108+
>>> for path, module in nnx.iter_module_children(model):
3109+
... print(path, type(module).__name__)
3110+
...
3111+
>>> for path, module in nnx.iter_children(model):
3112+
... print(path, type(module).__name__)
3113+
...
3114+
bias Param
3115+
kernel Param
3116+
3117+
Args:
3118+
node: A graph node object.
3119+
graph: If ``True`` (default), uses graph-mode which supports the full
3120+
NNX feature set including shared references. If ``False``, uses
3121+
tree-mode which treats Modules as regular JAX pytrees, avoiding
3122+
the overhead of the graph protocol.
3123+
"""
3124+
if graph is None:
3125+
graph = set_graph_mode.current_value()
3126+
if graph:
3127+
node_impl = get_node_impl(node)
3128+
if node_impl is None:
3129+
raise ValueError(
3130+
f'Expected a graph node, got {type(node).__name__}. '
3131+
'If this is a regular pytree, use graph=False.'
3132+
)
3133+
node_dict = node_impl.node_dict(node)
3134+
for key, value in node_dict.items():
3135+
if is_node_module(value):
3136+
yield key, value
3137+
else:
3138+
_check_valid_pytree(node, 'iter_children')
3139+
if not is_pytree_node(node, check_graph_registry=False):
3140+
raise ValueError(
3141+
f'Expected a pytree node, got {type(node).__name__}. '
3142+
'If this is a graph node, use graph=True.'
3143+
)
3144+
children, _ = jax.tree_util.tree_flatten_with_path(
3145+
node, is_leaf=lambda x: x is not node
3146+
)
3147+
for jax_key_path, child in children:
3148+
if is_node_module(child):
3149+
key = _key_path_to_key(jax_key_path[0])
3150+
yield key, child
3151+
3152+
30883153
def recursive_map(
30893154
f: tp.Callable[[PathParts, tp.Any], tp.Any],
30903155
node: tp.Any,

flax/nnx/module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,7 @@ def iter_modules(
766766
yield path, value
767767

768768
iter_children = graphlib.iter_children
769+
iter_module_children = graphlib.iter_module_children
769770

770771
P = tp.ParamSpec("P")
771772
R = tp.TypeVar("R")

0 commit comments

Comments
 (0)