@@ -296,6 +296,10 @@ def is_node_type(x: type[tp.Any]) -> bool:
296296 return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree
297297
298298
299+ def is_node_module (x : tp .Any ) -> bool :
300+ return type (x ) in GRAPH_REGISTRY
301+
302+
299303def get_node_impl (x : Node ) -> NodeImpl [Node , tp .Any , tp .Any ] | None :
300304 if isinstance (x , Variable ):
301305 return None
@@ -324,6 +328,7 @@ def get_node_impl_for_type(
324328 else :
325329 return None
326330
331+
327332# use type-aware sorting to support int keys
328333def _type_aware_sort (item : tuple [tp .Any , tp .Any ]) -> tuple [int , tp .Any ]:
329334 key , _ = item
@@ -571,6 +576,7 @@ class NodeAttr:
571576
572577NODE_ATTR = NodeAttr ()
573578
579+
574580@dataclasses .dataclass (frozen = True , slots = True )
575581class LeafAttr :
576582 pass
@@ -839,6 +845,7 @@ def flatten( # type: ignore[invalid-annotation]
839845 else :
840846 return graphdef , leaves
841847
848+
842849@dataclasses .dataclass (frozen = True , slots = True )
843850class DataElem :
844851 value : tp .Any
@@ -848,6 +855,7 @@ class DataElem:
848855class StaticElem :
849856 value : tp .Any
850857
858+
851859def _graph_flatten (
852860 node : Node ,
853861 node_impl : NodeImpl [Node , Leaf , AuxData ] | None ,
@@ -3062,6 +3070,62 @@ def iter_children(
30623070 yield key , child
30633071
30643072
3073+ def iter_module_children (
3074+ node : tp .Any , / , * , graph : bool | None = None ,
3075+ ) -> tp .Iterator [tuple [Key , tp .Any ]]:
3076+ """Iterates over all module children of a given node. This function is similar
3077+ to :func:`iter_children`, except it only iterates over the module children
3078+ only.
3079+
3080+ Example::
3081+
3082+ >>> from flax import nnx
3083+ ...
3084+ >>> model = nnx.Linear(2, 5, rngs=nnx.Rngs(0))
3085+ >>> for path, module in nnx.iter_module_children(model):
3086+ ... print(path, type(module).__name__)
3087+ ...
3088+ >>> for path, module in nnx.iter_children(model):
3089+ ... print(path, type(module).__name__)
3090+ ...
3091+ kernel Param
3092+
3093+ Args:
3094+ node: A graph node object.
3095+ graph: If ``True`` (default), uses graph-mode which supports the full
3096+ NNX feature set including shared references. If ``False``, uses
3097+ tree-mode which treats Modules as regular JAX pytrees, avoiding
3098+ the overhead of the graph protocol.
3099+ """
3100+ if graph is None :
3101+ graph = set_graph_mode .current_value ()
3102+ if graph :
3103+ node_impl = get_node_impl (node )
3104+ if node_impl is None :
3105+ raise ValueError (
3106+ f'Expected a graph node, got { type (node ).__name__ } . '
3107+ 'If this is a regular pytree, use graph=False.'
3108+ )
3109+ node_dict = node_impl .node_dict (node )
3110+ for key , value in node_dict .items ():
3111+ if is_node_module (value ):
3112+ yield key , value
3113+ else :
3114+ _check_valid_pytree (node , 'iter_children' )
3115+ if not is_pytree_node (node , check_graph_registry = False ):
3116+ raise ValueError (
3117+ f'Expected a pytree node, got { type (node ).__name__ } . '
3118+ 'If this is a graph node, use graph=True.'
3119+ )
3120+ children , _ = jax .tree_util .tree_flatten_with_path (
3121+ node , is_leaf = lambda x : x is not node
3122+ )
3123+ for jax_key_path , child in children :
3124+ if is_node_module (child ):
3125+ key = _key_path_to_key (jax_key_path [0 ])
3126+ yield key , child
3127+
3128+
30653129def recursive_map (
30663130 f : tp .Callable [[PathParts , tp .Any ], tp .Any ],
30673131 node : tp .Any ,
0 commit comments