We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9c33569 commit ba01eeaCopy full SHA for ba01eea
flax/nnx/graphlib.py
@@ -3038,7 +3038,7 @@ def iter_children(
3038
)
3039
node_dict = node_impl.node_dict(node)
3040
for key, value in node_dict.items():
3041
- if is_graph_node(value):
+ if type(value) in GRAPH_REGISTRY:
3042
yield key, value
3043
else:
3044
_check_valid_pytree(node, 'iter_children')
@@ -3051,7 +3051,7 @@ def iter_children(
3051
node, is_leaf=lambda x: x is not node
3052
3053
for jax_key_path, child in children:
3054
- if is_graph_node(child):
+ if type(child) in GRAPH_REGISTRY:
3055
key = _key_path_to_key(jax_key_path[0])
3056
yield key, child
3057
0 commit comments