@@ -319,6 +319,10 @@ def is_node_type(x: type[tp.Any]) -> bool:
319319 return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
320320
321321
322+ def is_node_module (x : tp .Any ) -> bool :
323+ return type (x ) in GRAPH_REGISTRY
324+
325+
322326def get_node_impl (x : Node ) -> NodeImpl [Node , tp .Any , tp .Any ] | None :
323327 if isinstance (x , Variable ):
324328 return None
@@ -347,6 +351,7 @@ def get_node_impl_for_type(
347351 else :
348352 return None
349353
354+
350355# use type-aware sorting to support int keys
351356def _type_aware_sort (item : tuple [tp .Any , tp .Any ]) -> tuple [int , tp .Any ]:
352357 key , _ = item
@@ -594,6 +599,7 @@ class NodeAttr:
594599
595600NODE_ATTR = NodeAttr ()
596601
602+
597603@dataclasses .dataclass (frozen = True , slots = True )
598604class LeafAttr :
599605 pass
@@ -858,6 +864,7 @@ def flatten( # type: ignore[invalid-annotation]
858864 else :
859865 return graphdef , leaves
860866
867+
861868@dataclasses .dataclass (frozen = True , slots = True )
862869class DataElem :
863870 value : tp .Any
@@ -867,6 +874,7 @@ class DataElem:
867874class StaticElem :
868875 value : tp .Any
869876
877+
870878def _graph_flatten (
871879 node : Node ,
872880 node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
@@ -3128,6 +3136,63 @@ def iter_children(
31283136 yield key , child
31293137
31303138
3139+ def iter_module_children (
3140+ node : tp .Any , / , * , graph : bool | None = None ,
3141+ ) -> tp .Iterator [tuple [Key , tp .Any ]]:
3142+ """Iterates over all module children of a given node. This function is similar
3143+ to :func:`iter_children`, except it only iterates over the module children
3144+ only.
3145+
3146+ Example::
3147+
3148+ >>> from flax import nnx
3149+ ...
3150+ >>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0))
3151+ >>> for path, module in nnx.iter_module_children(model):
3152+ ... print(path, type(module).__name__)
3153+ ...
3154+ >>> for path, module in nnx.iter_children(model):
3155+ ... print(path, type(module).__name__)
3156+ ...
3157+ bias Param
3158+ kernel Param
3159+
3160+ Args:
3161+ node: A graph node object.
3162+ graph: If ``True`` (default), uses graph-mode which supports the full
3163+ NNX feature set including shared references. If ``False``, uses
3164+ tree-mode which treats Modules as regular JAX pytrees, avoiding
3165+ the overhead of the graph protocol.
3166+ """
3167+ if graph is None :
3168+ graph = set_graph_mode .current_value ()
3169+ if graph :
3170+ node_impl = get_node_impl (node )
3171+ if node_impl is None :
3172+ raise ValueError (
3173+ f'Expected a graph node, got { type (node ).__name__ } . '
3174+ 'If this is a regular pytree, use graph=False.'
3175+ )
3176+ node_dict = node_impl .node_dict (node )
3177+ for key , value in node_dict .items ():
3178+ if is_node_module (value ):
3179+ yield key , value
3180+ else :
3181+ _check_valid_pytree (node , 'iter_children' )
3182+ if not is_pytree_node (node , check_graph_registry = False ):
3183+ raise ValueError (
3184+ f'Expected a pytree node, got { type (node ).__name__ } . '
3185+ 'If this is a graph node, use graph=True.'
3186+ )
3187+ children , _ = jax .tree_util .tree_flatten_with_path (
3188+ node , is_leaf = lambda x : x is not node
3189+ )
3190+ for jax_key_path , child in children :
3191+ if is_node_module (child ):
3192+ key = _key_path_to_key (jax_key_path [0 ])
3193+ yield key , child
3194+
3195+
31313196def recursive_map (
31323197 f : tp .Callable [[PathParts , tp .Any ], tp .Any ],
31333198 node : tp .Any ,
0 commit comments