Skip to content

Commit ba01eea

Browse files
hyunn9973Flax Authors
authored andcommitted
Iterate over modules only for modules through iter_children().
PiperOrigin-RevId: 878177716
1 parent 9c33569 commit ba01eea

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flax/nnx/graphlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3038,7 +3038,7 @@ def iter_children(
30383038
)
30393039
node_dict = node_impl.node_dict(node)
30403040
for key, value in node_dict.items():
3041-
if is_graph_node(value):
3041+
if type(value) in GRAPH_REGISTRY:
30423042
yield key, value
30433043
else:
30443044
_check_valid_pytree(node, 'iter_children')
@@ -3051,7 +3051,7 @@ def iter_children(
30513051
node, is_leaf=lambda x: x is not node
30523052
)
30533053
for jax_key_path, child in children:
3054-
if is_graph_node(child):
3054+
if type(child) in GRAPH_REGISTRY:
30553055
key = _key_path_to_key(jax_key_path[0])
30563056
yield key, child
30573057

0 commit comments

Comments
 (0)