4848KeyT = tp .TypeVar ('KeyT' , bound = Key )
4949
5050Index = int
51+
52+ def _tree_mode_suggestion (fn_name : str ) -> str :
53+ return (
54+ f'\n \n If the structure is intended to be a graph, consider '
55+ f'using graph=True or nnx.graph.{ fn_name } .'
56+ )
57+
58+ def _check_valid_pytree (node : tp .Any , fn_name : str ) -> None :
59+ from flax .nnx import pytreelib
60+ if (
61+ isinstance (node , pytreelib .Pytree )
62+ and not node ._pytree__is_pytree
63+ ):
64+ raise ValueError (
65+ f"Cannot use '{ fn_name } ' with graph=False on a "
66+ f"'{ type (node ).__name__ } ' instance that has pytree=False. "
67+ f"Pytree subclasses with pytree=False are not registered as "
68+ f"JAX pytrees and cannot be used in tree-mode. "
69+ + _tree_mode_suggestion (fn_name )
70+ )
71+
5172Names = tp .Sequence [int ]
5273Node = tp .TypeVar ('Node' )
5374Leaf = tp .TypeVar ('Leaf' )
@@ -637,9 +658,13 @@ def _tree_flatten(
637658 leaves : list [tp .Any ],
638659 paths : list [PathParts ] | None ,
639660) -> None :
640- is_variable = lambda x : isinstance (x , Variable )
661+ def _is_leaf (x ):
662+ if isinstance (x , Variable ):
663+ return True
664+ _check_valid_pytree (x , 'flatten' )
665+ return False
641666 jax_leaves , treedef = jax .tree_util .tree_flatten_with_path (
642- node , is_leaf = is_variable
667+ node , is_leaf = _is_leaf
643668 )
644669 nnx_paths_and_leaves : list [tuple [PathParts , tp .Any ]] = [
645670 (jax_to_nnx_path (jax_path ), value ) for jax_path , value in jax_leaves
@@ -666,7 +691,8 @@ def _tree_flatten(
666691 if var_id in seen_variables :
667692 raise ValueError (
668693 f'Duplicate Variable found at path { nnx_path !r} . '
669- 'Tree mode (graph=False) does not support shared references.'
694+ 'Tree mode (graph=False) does not support shared references. '
695+ + _tree_mode_suggestion ('split' )
670696 )
671697 seen_variables .add (var_id )
672698 raw_value = value .get_raw_value ()
@@ -675,7 +701,8 @@ def _tree_flatten(
675701 if ref_id in seen_refs :
676702 raise ValueError (
677703 f'Duplicate Ref found inside Variable at path { nnx_path !r} . '
678- 'Tree mode (graph=False) does not support shared references.'
704+ 'Tree mode (graph=False) does not support shared references. '
705+ + _tree_mode_suggestion ('split' )
679706 )
680707 seen_refs .add (ref_id )
681708 nodes .append (VariableDef (
@@ -690,7 +717,8 @@ def _tree_flatten(
690717 if ref_id in seen_refs :
691718 raise ValueError (
692719 f'Duplicate Ref found at path { nnx_path !r} . '
693- 'Tree mode (graph=False) does not support shared references.'
720+ 'Tree mode (graph=False) does not support shared references. '
721+ + _tree_mode_suggestion ('split' )
694722 )
695723 seen_refs .add (ref_id )
696724 leaves .append (value )
@@ -2924,13 +2952,15 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29242952 continue
29252953
29262954 if not is_pytree_node (current , check_graph_registry = False ):
2955+ _check_valid_pytree (current , 'iter_graph' )
29272956 if isinstance (current , Variable ) or variablelib .is_array_ref (current ):
29282957 obj_id = id (current )
29292958 if obj_id in seen_refs :
29302959 raise ValueError (
29312960 f'Found duplicate Variable or Ref at path '
29322961 f'"{ "/" .join (map (str , path ))} ". '
2933- 'Shared references are not supported with graph=False.'
2962+ 'Shared references are not supported with graph=False. '
2963+ + _tree_mode_suggestion ('iter_graph' )
29342964 )
29352965 seen_refs .add (obj_id )
29362966 yield path , current
@@ -2940,7 +2970,8 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29402970 if obj_id in in_progress :
29412971 raise ValueError (
29422972 f'Found cycle at path "{ "/" .join (map (str , path ))} ". '
2943- 'Cycles are not supported with graph=False.'
2973+ 'Cycles are not supported with graph=False. '
2974+ + _tree_mode_suggestion ('iter_graph' )
29442975 )
29452976 in_progress .add (obj_id )
29462977
@@ -3010,6 +3041,7 @@ def iter_children(
30103041 if is_graph_node (value ):
30113042 yield key , value
30123043 else :
3044+ _check_valid_pytree (node , 'iter_children' )
30133045 if not is_pytree_node (node , check_graph_registry = False ):
30143046 raise ValueError (
30153047 f'Expected a pytree node, got { type (node ).__name__ } . '
@@ -3123,13 +3155,15 @@ def _recursive_map_tree(
31233155
31243156 def _recurse (path : PathParts , current : tp .Any ) -> tp .Any :
31253157 if not is_pytree_node (current , check_graph_registry = False ):
3158+ _check_valid_pytree (current , 'recursive_map' )
31263159 if isinstance (current , Variable ) or is_array_ref (current ):
31273160 obj_id = id (current )
31283161 if obj_id in seen_refs :
31293162 raise ValueError (
31303163 f'Found duplicate Variable or Ref at path '
31313164 f'"{ "/" .join (map (str , path ))} ". '
3132- 'Shared references are not supported with graph=False.'
3165+ 'Shared references are not supported with graph=False. '
3166+ + _tree_mode_suggestion ('recursive_map' )
31333167 )
31343168 seen_refs .add (obj_id )
31353169 return f (path , current )
@@ -3138,7 +3172,8 @@ def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
31383172 if obj_id in in_progress :
31393173 raise ValueError (
31403174 f'Found cycle at path "{ "/" .join (map (str , path ))} ". '
3141- 'Cycles are not supported with graph=False.'
3175+ 'Cycles are not supported with graph=False. '
3176+ + _tree_mode_suggestion ('recursive_map' )
31423177 )
31433178 in_progress .add (obj_id )
31443179
0 commit comments