@@ -318,6 +318,10 @@ def is_node_type(x: type[tp.Any]) -> bool:
318318 return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
319319
320320
321+ def is_node_module (x : tp .Any ) -> bool :
322+ return type (x ) in GRAPH_REGISTRY
323+
324+
321325def get_node_impl (x : Node ) -> NodeImpl [Node , tp .Any , tp .Any ] | None :
322326 if isinstance (x , Variable ):
323327 return None
@@ -346,6 +350,7 @@ def get_node_impl_for_type(
346350 else :
347351 return None
348352
353+
349354# use type-aware sorting to support int keys
350355def _type_aware_sort (item : tuple [tp .Any , tp .Any ]) -> tuple [int , tp .Any ]:
351356 key , _ = item
@@ -593,6 +598,7 @@ class NodeAttr:
593598
594599NODE_ATTR = NodeAttr ()
595600
601+
596602@dataclasses .dataclass (frozen = True , slots = True )
597603class LeafAttr :
598604 pass
@@ -857,6 +863,7 @@ def flatten( # type: ignore[invalid-annotation]
857863 else :
858864 return graphdef , leaves
859865
866+
860867@dataclasses .dataclass (frozen = True , slots = True )
861868class DataElem :
862869 value : tp .Any
@@ -866,6 +873,7 @@ class DataElem:
866873class StaticElem :
867874 value : tp .Any
868875
876+
869877def _graph_flatten (
870878 node : Node ,
871879 node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
@@ -3085,6 +3093,63 @@ def iter_children(
30853093 yield key , child
30863094
30873095
3096+ def iter_module_children (
3097+ node : tp .Any , / , * , graph : bool | None = None ,
3098+ ) -> tp .Iterator [tuple [Key , tp .Any ]]:
3099+ """Iterates over all module children of a given node. This function is similar
3100+ to :func:`iter_children`, except it only iterates over the module children
3101+ only.
3102+
3103+ Example::
3104+
3105+ >>> from flax import nnx
3106+ ...
3107+ >>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0))
3108+ >>> for path, module in nnx.iter_module_children(model):
3109+ ... print(path, type(module).__name__)
3110+ ...
3111+ >>> for path, module in nnx.iter_children(model):
3112+ ... print(path, type(module).__name__)
3113+ ...
3114+ bias Param
3115+ kernel Param
3116+
3117+ Args:
3118+ node: A graph node object.
3119+ graph: If ``True`` (default), uses graph-mode which supports the full
3120+ NNX feature set including shared references. If ``False``, uses
3121+ tree-mode which treats Modules as regular JAX pytrees, avoiding
3122+ the overhead of the graph protocol.
3123+ """
3124+ if graph is None :
3125+ graph = set_graph_mode .current_value ()
3126+ if graph :
3127+ node_impl = get_node_impl (node )
3128+ if node_impl is None :
3129+ raise ValueError (
3130+ f'Expected a graph node, got { type (node ).__name__ } . '
3131+ 'If this is a regular pytree, use graph=False.'
3132+ )
3133+ node_dict = node_impl .node_dict (node )
3134+ for key , value in node_dict .items ():
3135+ if is_node_module (value ):
3136+ yield key , value
3137+ else :
3138+ _check_valid_pytree (node , 'iter_children' )
3139+ if not is_pytree_node (node , check_graph_registry = False ):
3140+ raise ValueError (
3141+ f'Expected a pytree node, got { type (node ).__name__ } . '
3142+ 'If this is a graph node, use graph=True.'
3143+ )
3144+ children , _ = jax .tree_util .tree_flatten_with_path (
3145+ node , is_leaf = lambda x : x is not node
3146+ )
3147+ for jax_key_path , child in children :
3148+ if is_node_module (child ):
3149+ key = _key_path_to_key (jax_key_path [0 ])
3150+ yield key , child
3151+
3152+
30883153def recursive_map (
30893154 f : tp .Callable [[PathParts , tp .Any ], tp .Any ],
30903155 node : tp .Any ,
0 commit comments