@@ -1471,24 +1471,18 @@ def custom_vjp(
14711471 defined by ``custom_vjp``, in the example above we reuse the same ``x_attribute``
14721472 filter to keep ``custom_vjp`` and ``grad`` in sync.
14731473
1474- **graph_updates behavior**
1475-
1476- When ``graph_updates=True`` and ``graph=True``, the ``bwd`` function
1477- receives gradients as ``(input_updates_g, out_g)`` where ``input_updates_g`` is a
1478- tuple of ``nnx.State`` objects (one per Module argument) representing the gradient
1479- of the updated state. Non-Module arguments appear as ``None``. The ``bwd`` function
1480- must return tangents with the same structure, using ``State`` objects for Module terms.
1481- In this mode, state mutations inside ``f`` are propagated to the inputs.
1474+ **graph_updates=False**
14821475
14831476 When ``graph_updates=False`` or ``graph=False``, the behavior is closer to
14841477 ``jax.custom_vjp``: the ``bwd`` function receives ``out_g`` directly, and
1485- tangents for Module arguments are Module instances (or clones) with gradient
1486- values set on their fields. This mode does not support ``DiffState`` in
1487- ``nondiff_argnums``. Additionally, Variables in differentiable arguments cannot
1488- not be mutated inside ``f``. If mutations are needed, pass the
1489- relevant Variables through a non-differentiable argument instead.
1478+ tangent types are the same as the input types, this means the tangent for a
1479+ Module is a Module instance with gradient values set on its attributes.
1480+ This mode does not support ``DiffState`` in ``nondiff_argnums``. Additionally,
1481+ Variables in differentiable arguments cannot be mutated inside ``f``. If
1482+ mutations are needed, pass the relevant Variables through a non-differentiable
1483+ argument instead.
14901484
1491- Example with ``graph_updates=False`` ::
1485+ Example::
14921486
14931487 >>> @nnx.custom_vjp(graph_updates=False)
14941488 ... def f(m: Foo):
@@ -1503,6 +1497,8 @@ def custom_vjp(
15031497 ... m_g.x[...] = cos_x * g * m.y
15041498 ... m_g.y[...] = sin_x * g
15051499 ... return (m_g,)
1500+ ...
1501+ >>> f.defvjp(f_fwd, f_bwd)
15061502
15071503 Args:
15081504 fun: Callable base function.
0 commit comments