Skip to content

Commit 9c33569

Browse files
Cristian GarciaFlax Authors
authored andcommitted
improve error messages for tree mode errors
PiperOrigin-RevId: 878177405
1 parent bbfab0a commit 9c33569

File tree

4 files changed

+60
-17
lines changed

4 files changed

+60
-17
lines changed

flax/nnx/graphlib.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@
4848
KeyT = tp.TypeVar('KeyT', bound=Key)
4949

5050
Index = int
51+
52+
def _tree_mode_suggestion(fn_name: str) -> str:
53+
return (
54+
f'\n\nIf the structure is intended to be a graph, consider '
55+
f'using graph=True or nnx.graph.{fn_name}.'
56+
)
57+
58+
def _check_valid_pytree(node: tp.Any, fn_name: str) -> None:
59+
from flax.nnx import pytreelib
60+
if (
61+
isinstance(node, pytreelib.Pytree)
62+
and not node._pytree__is_pytree
63+
):
64+
raise ValueError(
65+
f"Cannot use '{fn_name}' with graph=False on a "
66+
f"'{type(node).__name__}' instance that has pytree=False. "
67+
f"Pytree subclasses with pytree=False are not registered as "
68+
f"JAX pytrees and cannot be used in tree-mode. "
69+
+ _tree_mode_suggestion(fn_name)
70+
)
71+
5172
Names = tp.Sequence[int]
5273
Node = tp.TypeVar('Node')
5374
Leaf = tp.TypeVar('Leaf')
@@ -637,9 +658,13 @@ def _tree_flatten(
637658
leaves: list[tp.Any],
638659
paths: list[PathParts] | None,
639660
) -> None:
640-
is_variable = lambda x: isinstance(x, Variable)
661+
def _is_leaf(x):
662+
if isinstance(x, Variable):
663+
return True
664+
_check_valid_pytree(x, 'flatten')
665+
return False
641666
jax_leaves, treedef = jax.tree_util.tree_flatten_with_path(
642-
node, is_leaf=is_variable
667+
node, is_leaf=_is_leaf
643668
)
644669
nnx_paths_and_leaves: list[tuple[PathParts, tp.Any]] = [
645670
(jax_to_nnx_path(jax_path), value) for jax_path, value in jax_leaves
@@ -666,7 +691,8 @@ def _tree_flatten(
666691
if var_id in seen_variables:
667692
raise ValueError(
668693
f'Duplicate Variable found at path {nnx_path!r}. '
669-
'Tree mode (graph=False) does not support shared references.'
694+
'Tree mode (graph=False) does not support shared references. '
695+
+ _tree_mode_suggestion('split')
670696
)
671697
seen_variables.add(var_id)
672698
raw_value = value.get_raw_value()
@@ -675,7 +701,8 @@ def _tree_flatten(
675701
if ref_id in seen_refs:
676702
raise ValueError(
677703
f'Duplicate Ref found inside Variable at path {nnx_path!r}. '
678-
'Tree mode (graph=False) does not support shared references.'
704+
'Tree mode (graph=False) does not support shared references. '
705+
+ _tree_mode_suggestion('split')
679706
)
680707
seen_refs.add(ref_id)
681708
nodes.append(VariableDef(
@@ -690,7 +717,8 @@ def _tree_flatten(
690717
if ref_id in seen_refs:
691718
raise ValueError(
692719
f'Duplicate Ref found at path {nnx_path!r}. '
693-
'Tree mode (graph=False) does not support shared references.'
720+
'Tree mode (graph=False) does not support shared references. '
721+
+ _tree_mode_suggestion('split')
694722
)
695723
seen_refs.add(ref_id)
696724
leaves.append(value)
@@ -2924,13 +2952,15 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29242952
continue
29252953

29262954
if not is_pytree_node(current, check_graph_registry=False):
2955+
_check_valid_pytree(current, 'iter_graph')
29272956
if isinstance(current, Variable) or variablelib.is_array_ref(current):
29282957
obj_id = id(current)
29292958
if obj_id in seen_refs:
29302959
raise ValueError(
29312960
f'Found duplicate Variable or Ref at path '
29322961
f'"{"/".join(map(str, path))}". '
2933-
'Shared references are not supported with graph=False.'
2962+
'Shared references are not supported with graph=False. '
2963+
+ _tree_mode_suggestion('iter_graph')
29342964
)
29352965
seen_refs.add(obj_id)
29362966
yield path, current
@@ -2940,7 +2970,8 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29402970
if obj_id in in_progress:
29412971
raise ValueError(
29422972
f'Found cycle at path "{"/".join(map(str, path))}". '
2943-
'Cycles are not supported with graph=False.'
2973+
'Cycles are not supported with graph=False. '
2974+
+ _tree_mode_suggestion('iter_graph')
29442975
)
29452976
in_progress.add(obj_id)
29462977

@@ -3010,6 +3041,7 @@ def iter_children(
30103041
if is_graph_node(value):
30113042
yield key, value
30123043
else:
3044+
_check_valid_pytree(node, 'iter_children')
30133045
if not is_pytree_node(node, check_graph_registry=False):
30143046
raise ValueError(
30153047
f'Expected a pytree node, got {type(node).__name__}. '
@@ -3123,13 +3155,15 @@ def _recursive_map_tree(
31233155

31243156
def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
31253157
if not is_pytree_node(current, check_graph_registry=False):
3158+
_check_valid_pytree(current, 'recursive_map')
31263159
if isinstance(current, Variable) or is_array_ref(current):
31273160
obj_id = id(current)
31283161
if obj_id in seen_refs:
31293162
raise ValueError(
31303163
f'Found duplicate Variable or Ref at path '
31313164
f'"{"/".join(map(str, path))}". '
3132-
'Shared references are not supported with graph=False.'
3165+
'Shared references are not supported with graph=False. '
3166+
+ _tree_mode_suggestion('recursive_map')
31333167
)
31343168
seen_refs.add(obj_id)
31353169
return f(path, current)
@@ -3138,7 +3172,8 @@ def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
31383172
if obj_id in in_progress:
31393173
raise ValueError(
31403174
f'Found cycle at path "{"/".join(map(str, path))}". '
3141-
'Cycles are not supported with graph=False.'
3175+
'Cycles are not supported with graph=False. '
3176+
+ _tree_mode_suggestion('recursive_map')
31423177
)
31433178
in_progress.add(obj_id)
31443179

flax/nnx/transforms/autodiff.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def _grad_general(
140140
if any(isinstance(x, DiffState) for x in jax.tree.leaves(argnums)):
141141
raise ValueError(
142142
'`argnums` cannot contain `DiffState` objects '
143-
'when `graph=False`'
143+
'when `graph=False`. '
144+
+ graphlib._tree_mode_suggestion('grad')
144145
)
145146

146147
gradded_fn = transform(
@@ -1321,7 +1322,8 @@ def custom_vjp(
13211322
if any(isinstance(x, DiffState) for x in nondiff_argnums):
13221323
raise ValueError(
13231324
'`nondiff_argnums` cannot contain `DiffState` objects '
1324-
'when `graph=False`'
1325+
'when `graph=False`. '
1326+
+ graphlib._tree_mode_suggestion('custom_vjp')
13251327
)
13261328
return TreeCustomVjp(fun_unbound, nondiff_argnums) # type: ignore[arg-type]
13271329

flax/nnx/transforms/compilation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,12 +361,14 @@ def jit(
361361
if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)):
362362
raise ValueError(
363363
'`in_shardings` cannot contain `StateSharding` objects '
364-
'when `graph=False`'
364+
'when `graph=False`. '
365+
+ graphlib._tree_mode_suggestion('jit')
365366
)
366367
if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_shardings)):
367368
raise ValueError(
368369
'`out_shardings` cannot contain `StateSharding` objects '
369-
'when `graph=False`'
370+
'when `graph=False`. '
371+
+ graphlib._tree_mode_suggestion('jit')
370372
)
371373

372374
wrapped_cls = JitWrapped if graph else TreeJitWrapped

flax/nnx/transforms/iteration.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,12 +386,14 @@ def vmap(
386386
if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
387387
raise ValueError(
388388
'`in_axes` cannot contain `StateAxes` objects '
389-
'when `graph=False`'
389+
'when `graph=False`. '
390+
+ graphlib._tree_mode_suggestion('vmap')
390391
)
391392
if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
392393
raise ValueError(
393394
'`out_axes` cannot contain `StateAxes` objects '
394-
'when `graph=False`'
395+
'when `graph=False`. '
396+
+ graphlib._tree_mode_suggestion('vmap')
395397
)
396398

397399
vmapped_fn = jax.vmap(
@@ -648,12 +650,14 @@ def pmap(
648650
if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
649651
raise ValueError(
650652
'`in_axes` cannot contain `StateAxes` objects '
651-
'when `graph=False`'
653+
'when `graph=False`. '
654+
+ graphlib._tree_mode_suggestion('pmap')
652655
)
653656
if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
654657
raise ValueError(
655658
'`out_axes` cannot contain `StateAxes` objects '
656-
'when `graph=False`'
659+
'when `graph=False`. '
660+
+ graphlib._tree_mode_suggestion('pmap')
657661
)
658662

659663
pmapped_fn = jax.pmap(

0 commit comments

Comments
 (0)