Skip to content

Commit 6a62992

Browse files
hyunn9973Flax Authors
authored andcommitted
Add iter_module_children() to iterate over module children only.
PiperOrigin-RevId: 878177716
1 parent 1779706 commit 6a62992

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

flax/nnx/graphlib.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ def is_node_type(x: type[tp.Any]) -> bool:
296296
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
297297

298298

299+
def is_node_module(x: tp.Any) -> bool:
300+
return type(x) in GRAPH_REGISTRY
301+
302+
299303
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any] | None:
300304
if isinstance(x, Variable):
301305
return None
@@ -324,6 +328,7 @@ def get_node_impl_for_type(
324328
else:
325329
return None
326330

331+
327332
# use type-aware sorting to support int keys
328333
def _type_aware_sort(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
329334
key, _ = item
@@ -571,6 +576,7 @@ class NodeAttr:
571576

572577
NODE_ATTR = NodeAttr()
573578

579+
574580
@dataclasses.dataclass(frozen=True, slots=True)
575581
class LeafAttr:
576582
pass
@@ -839,6 +845,7 @@ def flatten( # type: ignore[invalid-annotation]
839845
else:
840846
return graphdef, leaves
841847

848+
842849
@dataclasses.dataclass(frozen=True, slots=True)
843850
class DataElem:
844851
value: tp.Any
@@ -848,6 +855,7 @@ class DataElem:
848855
class StaticElem:
849856
value: tp.Any
850857

858+
851859
def _graph_flatten(
852860
node: Node,
853861
node_impl: NodeImpl[Node, Leaf, AuxData] | None,
@@ -3062,6 +3070,62 @@ def iter_children(
30623070
yield key, child
30633071

30643072

3073+
def iter_module_children(
3074+
node: tp.Any, /, *, graph: bool | None = None,
3075+
) -> tp.Iterator[tuple[Key, tp.Any]]:
3076+
"""Iterates over all module children of a given node. This function is similar
3077+
to :func:`iter_children`, except it only iterates over the module children
3078+
only.
3079+
3080+
Example::
3081+
3082+
>>> from flax import nnx
3083+
...
3084+
>>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0))
3085+
>>> for path, module in nnx.iter_module_children(model):
3086+
... print(path, type(module).__name__)
3087+
...
3088+
>>> for path, module in nnx.iter_children(model):
3089+
... print(path, type(module).__name__)
3090+
...
3091+
kernel Param
3092+
3093+
Args:
3094+
node: A graph node object.
3095+
graph: If ``True`` (default), uses graph-mode which supports the full
3096+
NNX feature set including shared references. If ``False``, uses
3097+
tree-mode which treats Modules as regular JAX pytrees, avoiding
3098+
the overhead of the graph protocol.
3099+
"""
3100+
if graph is None:
3101+
graph = set_graph_mode.current_value()
3102+
if graph:
3103+
node_impl = get_node_impl(node)
3104+
if node_impl is None:
3105+
raise ValueError(
3106+
f'Expected a graph node, got {type(node).__name__}. '
3107+
'If this is a regular pytree, use graph=False.'
3108+
)
3109+
node_dict = node_impl.node_dict(node)
3110+
for key, value in node_dict.items():
3111+
if is_node_module(value):
3112+
yield key, value
3113+
else:
3114+
_check_valid_pytree(node, 'iter_children')
3115+
if not is_pytree_node(node, check_graph_registry=False):
3116+
raise ValueError(
3117+
f'Expected a pytree node, got {type(node).__name__}. '
3118+
'If this is a graph node, use graph=True.'
3119+
)
3120+
children, _ = jax.tree_util.tree_flatten_with_path(
3121+
node, is_leaf=lambda x: x is not node
3122+
)
3123+
for jax_key_path, child in children:
3124+
if is_node_module(child):
3125+
key = _key_path_to_key(jax_key_path[0])
3126+
yield key, child
3127+
3128+
30653129
def recursive_map(
30663130
f: tp.Callable[[PathParts, tp.Any], tp.Any],
30673131
node: tp.Any,

0 commit comments

Comments
 (0)