@@ -2953,6 +2953,77 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29532953 stack .append (((* path , key ), child , False ))
29542954
29552955
2956+ def iter_children (
2957+ node : tp .Any , / , * , graph : bool | None = None ,
2958+ ) -> tp .Iterator [tuple [Key , tp .Any ]]:
2959+ """Iterates over all immediate child nodes of a given node. This
2960+ function is similar to :func:`iter_graph`, except it only iterates over the
2961+ immediate children, and does not recurse further down.
2962+
2963+ Specifically, this function creates a generator that yields the key and the
2964+ child node instance, where the key is a string representing the attribute
2965+ name to access the corresponding child.
2966+
2967+ Example::
2968+
2969+ >>> from flax import nnx
2970+ ...
2971+ >>> class SubModule(nnx.Module):
2972+ ... def __init__(self, din, dout, rngs):
2973+ ... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
2974+ ... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
2975+ ...
2976+ >>> class Block(nnx.Module):
2977+ ... def __init__(self, din, dout, *, rngs: nnx.Rngs):
2978+ ... self.linear = nnx.Linear(din, dout, rngs=rngs)
2979+ ... self.submodule = SubModule(din, dout, rngs=rngs)
2980+ ... self.dropout = nnx.Dropout(0.5)
2981+ ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
2982+ ...
2983+ >>> model = Block(2, 5, rngs=nnx.Rngs(0))
2984+ >>> for path, module in nnx.iter_children(model):
2985+ ... print(path, type(module).__name__)
2986+ ...
2987+ batch_norm BatchNorm
2988+ dropout Dropout
2989+ linear Linear
2990+ submodule SubModule
2991+
2992+ Args:
2993+ node: A graph node object.
2994+ graph: If ``True`` (default), uses graph-mode which supports the full
2995+ NNX feature set including shared references. If ``False``, uses
2996+ tree-mode which treats Modules as regular JAX pytrees, avoiding
2997+ the overhead of the graph protocol.
2998+ """
2999+ if graph is None :
3000+ graph = set_graph_mode .current_value ()
3001+ if graph :
3002+ node_impl = get_node_impl (node )
3003+ if node_impl is None :
3004+ raise ValueError (
3005+ f'Expected a graph node, got { type (node ).__name__ } . '
3006+ 'If this is a regular pytree, use graph=False.'
3007+ )
3008+ node_dict = node_impl .node_dict (node )
3009+ for key , value in node_dict .items ():
3010+ if is_graph_node (value ):
3011+ yield key , value
3012+ else :
3013+ if not is_pytree_node (node , check_graph_registry = False ):
3014+ raise ValueError (
3015+ f'Expected a pytree node, got { type (node ).__name__ } . '
3016+ 'If this is a graph node, use graph=True.'
3017+ )
3018+ children , _ = jax .tree_util .tree_flatten_with_path (
3019+ node , is_leaf = lambda x : x is not node
3020+ )
3021+ for jax_key_path , child in children :
3022+ if is_graph_node (child ):
3023+ key = _key_path_to_key (jax_key_path [0 ])
3024+ yield key , child
3025+
3026+
29563027def recursive_map (
29573028 f : tp .Callable [[PathParts , tp .Any ], tp .Any ],
29583029 node : tp .Any ,
0 commit comments