Skip to content

Commit e067f39

Browse files
Cristian GarciaFlax Authors
authored andcommitted
add error handling to iter_children
PiperOrigin-RevId: 877649499
1 parent 972c34e commit e067f39

File tree

3 files changed

+73
-59
lines changed

3 files changed

+73
-59
lines changed

flax/nnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
view = functools.partial(_module.view, graph=True)
4444
view_info = functools.partial(_module.view_info, graph=True)
4545
iter_modules = functools.partial(_module.iter_modules, graph=True)
46-
iter_children = functools.partial(_module.iter_children, graph=True)
46+
iter_children = functools.partial(_module.iter_children, graph=True) # type: ignore[has-type]
4747

4848
# rnglib
4949
split_rngs = functools.partial(_rnglib.split_rngs, graph=True)

flax/nnx/graphlib.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2953,6 +2953,77 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29532953
stack.append(((*path, key), child, False))
29542954

29552955

2956+
def iter_children(
2957+
node: tp.Any, /, *, graph: bool | None = None,
2958+
) -> tp.Iterator[tuple[Key, tp.Any]]:
2959+
"""Iterates over all immediate child nodes of a given node. This
2960+
function is similar to :func:`iter_graph`, except it only iterates over the
2961+
immediate children, and does not recurse further down.
2962+
2963+
Specifically, this function creates a generator that yields the key and the
2964+
child node instance, where the key is a string representing the attribute
2965+
name to access the corresponding child.
2966+
2967+
Example::
2968+
2969+
>>> from flax import nnx
2970+
...
2971+
>>> class SubModule(nnx.Module):
2972+
... def __init__(self, din, dout, rngs):
2973+
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
2974+
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
2975+
...
2976+
>>> class Block(nnx.Module):
2977+
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
2978+
... self.linear = nnx.Linear(din, dout, rngs=rngs)
2979+
... self.submodule = SubModule(din, dout, rngs=rngs)
2980+
... self.dropout = nnx.Dropout(0.5)
2981+
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
2982+
...
2983+
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
2984+
>>> for path, module in nnx.iter_children(model):
2985+
... print(path, type(module).__name__)
2986+
...
2987+
batch_norm BatchNorm
2988+
dropout Dropout
2989+
linear Linear
2990+
submodule SubModule
2991+
2992+
Args:
2993+
node: A graph node object.
2994+
graph: If ``True`` (default), uses graph-mode which supports the full
2995+
NNX feature set including shared references. If ``False``, uses
2996+
tree-mode which treats Modules as regular JAX pytrees, avoiding
2997+
the overhead of the graph protocol.
2998+
"""
2999+
if graph is None:
3000+
graph = set_graph_mode.current_value()
3001+
if graph:
3002+
node_impl = get_node_impl(node)
3003+
if node_impl is None:
3004+
raise ValueError(
3005+
f'Expected a graph node, got {type(node).__name__}. '
3006+
'If this is a regular pytree, use graph=False.'
3007+
)
3008+
node_dict = node_impl.node_dict(node)
3009+
for key, value in node_dict.items():
3010+
if is_graph_node(value):
3011+
yield key, value
3012+
else:
3013+
if not is_pytree_node(node, check_graph_registry=False):
3014+
raise ValueError(
3015+
f'Expected a pytree node, got {type(node).__name__}. '
3016+
'If this is a graph node, use graph=True.'
3017+
)
3018+
children, _ = jax.tree_util.tree_flatten_with_path(
3019+
node, is_leaf=lambda x: x is not node
3020+
)
3021+
for jax_key_path, child in children:
3022+
if is_graph_node(child):
3023+
key = _key_path_to_key(jax_key_path[0])
3024+
yield key, child
3025+
3026+
29563027
def recursive_map(
29573028
f: tp.Callable[[PathParts, tp.Any], tp.Any],
29583029
node: tp.Any,

flax/nnx/module.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -671,61 +671,4 @@ def iter_modules(
671671
if isinstance(value, Module):
672672
yield path, value
673673

674-
def iter_children(module: Module, graph: bool | None = None) -> tp.Iterator[tuple[Key, Module]]:
675-
"""Iterates over all children :class:`Module`'s of a given Module. This
676-
method is similar to :func:`iter_modules`, except it only iterates over the
677-
immediate children, and does not recurse further down.
678-
679-
Specifically, this function creates a generator that yields the key and the Module instance,
680-
where the key is a string representing the attribute name of the Module to access
681-
the corresponding child Module.
682-
683-
Example::
684-
685-
>>> from flax import nnx
686-
...
687-
>>> class SubModule(nnx.Module):
688-
... def __init__(self, din, dout, rngs):
689-
... self.linear1 = nnx.Linear(din, dout, rngs=rngs)
690-
... self.linear2 = nnx.Linear(din, dout, rngs=rngs)
691-
...
692-
>>> class Block(nnx.Module):
693-
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
694-
... self.linear = nnx.Linear(din, dout, rngs=rngs)
695-
... self.submodule = SubModule(din, dout, rngs=rngs)
696-
... self.dropout = nnx.Dropout(0.5)
697-
... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
698-
...
699-
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
700-
>>> for path, module in nnx.iter_children(model):
701-
... print(path, type(module).__name__)
702-
...
703-
batch_norm BatchNorm
704-
dropout Dropout
705-
linear Linear
706-
submodule SubModule
707-
708-
Args:
709-
module: A :class:`Module` object.
710-
graph: If ``True`` (default), uses graph-mode which supports the full
711-
NNX feature set including shared references. If ``False``, uses
712-
tree-mode which treats Modules as regular JAX pytrees, avoiding
713-
the overhead of the graph protocol.
714-
"""
715-
if graph is None:
716-
graph = graphlib.set_graph_mode.current_value()
717-
if graph:
718-
node_impl = graphlib.get_node_impl(module)
719-
assert node_impl is not None
720-
node_dict = node_impl.node_dict(module)
721-
for key, value in node_dict.items():
722-
if isinstance(value, Module):
723-
yield key, value
724-
else:
725-
children, _ = jax.tree_util.tree_flatten_with_path(
726-
module, is_leaf=lambda x: x is not module
727-
)
728-
for jax_key_path, child in children:
729-
if isinstance(child, Module):
730-
key = graphlib._key_path_to_key(jax_key_path[0])
731-
yield key, child
674+
iter_children = graphlib.iter_children

0 commit comments

Comments
 (0)