Skip to content

Commit 93c430d

Browse files
Cristian GarciaFlax Authors
authored andcommitted
clean up custom_vjp's graph_updates=False section
PiperOrigin-RevId: 885315479
1 parent 1779706 commit 93c430d

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

flax/nnx/transforms/autodiff.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,24 +1471,18 @@ def custom_vjp(
14711471
defined by ``custom_vjp``, in the example above we reuse the same ``x_attribute``
14721472
filter to keep ``custom_vjp`` and ``grad`` in sync.
14731473
1474-
**graph_updates behavior**
1475-
1476-
When ``graph_updates=True`` and ``graph=True``, the ``bwd`` function
1477-
receives gradients as ``(input_updates_g, out_g)`` where ``input_updates_g`` is a
1478-
tuple of ``nnx.State`` objects (one per Module argument) representing the gradient
1479-
of the updated state. Non-Module arguments appear as ``None``. The ``bwd`` function
1480-
must return tangents with the same structure, using ``State`` objects for Module terms.
1481-
In this mode, state mutations inside ``f`` are propagated to the inputs.
1474+
**graph_updates=False**
14821475
14831476
When ``graph_updates=False`` or ``graph=False``, the behavior is closer to
14841477
``jax.custom_vjp``: the ``bwd`` function receives ``out_g`` directly, and
1485-
tangents for Module arguments are Module instances (or clones) with gradient
1486-
values set on their fields. This mode does not support ``DiffState`` in
1487-
``nondiff_argnums``. Additionally, Variables in differentiable arguments cannot
1488-
not be mutated inside ``f``. If mutations are needed, pass the
1489-
relevant Variables through a non-differentiable argument instead.
1478+
tangent types are the same as the input types, this means the tangent for a
1479+
Module is a Module instance with gradient values set on its attributes.
1480+
This mode does not support ``DiffState`` in ``nondiff_argnums``. Additionally,
1481+
Variables in differentiable arguments cannot be mutated inside ``f``. If
1482+
mutations are needed, pass the relevant Variables through a non-differentiable
1483+
argument instead.
14901484
1491-
Example with ``graph_updates=False``::
1485+
Example::
14921486
14931487
>>> @nnx.custom_vjp(graph_updates=False)
14941488
... def f(m: Foo):
@@ -1503,6 +1497,8 @@ def custom_vjp(
15031497
... m_g.x[...] = cos_x * g * m.y
15041498
... m_g.y[...] = sin_x * g
15051499
... return (m_g,)
1500+
...
1501+
>>> f.defvjp(f_fwd, f_bwd)
15061502
15071503
Args:
15081504
fun: Callable base function.

0 commit comments

Comments
 (0)