diff --git a/docs_nnx/guides/pytree.ipynb b/docs_nnx/guides/pytree.ipynb index b0bf99405..0209be2e7 100644 --- a/docs_nnx/guides/pytree.ipynb +++ b/docs_nnx/guides/pytree.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 19, "id": "9b2b929d", "metadata": {}, "outputs": [ @@ -689,7 +689,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 26, "id": "668db479", "metadata": {}, "outputs": [ @@ -957,12 +957,12 @@ "metadata": {}, "source": [ "### sow\n", - "`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, NNX APIs such as `nnx.state` or `nnx.pop` are a good way of retrieving the sowed state, however `pop` is recommended because it explicitly removes the temporary state from the Module." + "`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, the `nnx.capture` function allows you to retrieve the sowed state. See the *Extracting Intermediate Values* guide for more details." ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 4, "id": "ca9f58a2", "metadata": {}, "outputs": [ @@ -1016,8 +1016,7 @@ "\n", "model = MLP(num_layers=3, dim=20, rngs=nnx.Rngs(0))\n", "x = jnp.ones((10, 20))\n", - "y = model(x)\n", - "intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values\n", + "_, intermediates = nnx.capture(model, nnx.Intermediate)(x)\n", "print(intermediates)" ] }, @@ -1027,7 +1026,7 @@ "metadata": {}, "source": [ "### perturb\n", - "`perturb` is similar to `sow` but it aims to capture the gradient of a value, currently this is a two step process although it might be simplified in the future:\n", + "`perturb` is similar to `sow` but it aims to capture the gradient of a value. This is a two-step process:\n", "1. 
Initialize the pertubation state by running the model once.\n", "2. Pass the perturbation state as a differentiable target to `grad`.\n", "\n", @@ -1036,7 +1035,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "id": "41398e14", "metadata": {}, "outputs": [ @@ -1044,7 +1043,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "nnx.state(model, nnx.Perturbation) = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + "\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[0., 0., 0.]], dtype=float32)\n", " \u001b[38;2;255;213;3m)\u001b[0m\n", @@ -1053,8 +1052,6 @@ } ], "source": [ - "import optax\n", - "\n", "class Model(nnx.Module):\n", " def __init__(self, rngs):\n", " self.linear1 = nnx.Linear(2, 3, rngs=rngs)\n", @@ -1067,10 +1064,9 @@ "\n", "rngs = nnx.Rngs(0)\n", "model = Model(rngs)\n", - "optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-1), wrt=nnx.Param)\n", "x, y = rngs.uniform((1, 2)), rngs.uniform((1, 4))\n", - "_ = model(x) # initialize perturbations\n", - "print(f\"{nnx.state(model, nnx.Perturbation) = !s}\")" + "_, perturbations = nnx.capture(model, nnx.Perturbation)(x) # initialize perturbations\n", + "print(perturbations)" ] }, { @@ -1078,49 +1074,24 @@ "id": "c9221005", "metadata": {}, "source": [ - "Next we'll create a training step function that differentiates w.r.t. both the parameters of the model and the perturbations, the later will be the gradients for the intermediate values. `nnx.jit` and `nnx.value_and_grad` will be use to automatically propagate state updates. 
We'll return the `loss` function and the itermediate gradients." + "Next we'll differentiate a loss w.r.t. both the parameters of the model and the perturbations: the latter will be the gradients for the intermediate values." ] }, { "cell_type": "code", - "execution_count": 24, - "id": "d10effba", + "execution_count": 15, + "id": "56cc9fde-ce49-49ef-84ab-01fa5c91914b", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step = 0, loss = Array(0.7326511, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", - " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.430146 , -0.14356601, 0.2935633 ]], dtype=float32)\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m})\u001b[0m\n", - "step = 1, loss = Array(0.65039134, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", - " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.38535568, -0.11745065, 0.24441527]], dtype=float32)\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m})\u001b[0m\n" - ] - } - ], + "outputs": [], "source": [ - "@nnx.jit\n", - "def train_step(model, optimizer, x, y):\n", - " graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)\n", - "\n", - " def loss_fn(params, perturbations):\n", - " model = nnx.merge(graphdef, params, perturbations)\n", - " return 
jnp.mean((model(x) - y) ** 2)\n", - "\n", - " loss, (grads, iterm_grads) = nnx.value_and_grad(loss_fn, argnums=(0, 1))(params, perturbations)\n", - " optimizer.update(model, grads)\n", - "\n", - " return loss, iterm_grads\n", + "@nnx.value_and_grad(argnums=(0, 1))\n", + "def loss_grad(model, perturbations, x, y):\n", + " def loss_fn(model):\n", + " return jnp.mean((model(x) - y) ** 2)\n", + " return nnx.capture(loss_fn, init=perturbations)(model)\n", "\n", - "for step in range(2):\n", - " loss, iterm_grads = train_step(model, optimizer, x, y)\n", - " print(f\"{step = }, {loss = }, {iterm_grads = !s}\")" + "loss, (weight_grads, interm_grads) = loss_grad(model, perturbations, x, y)\n", + "interm_grads" ] }, { @@ -1194,7 +1165,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.13.5" } }, "nbformat": 4, diff --git a/docs_nnx/guides/pytree.md b/docs_nnx/guides/pytree.md index 2c9f46caf..7bec38b6e 100644 --- a/docs_nnx/guides/pytree.md +++ b/docs_nnx/guides/pytree.md @@ -430,7 +430,7 @@ NNX Modules are `Pytree`s that have two additional methods for traking intermedi +++ ### sow -`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, NNX APIs such as `nnx.state` or `nnx.pop` are a good way of retrieving the sowed state, however `pop` is recommended because it explicitly removes the temporary state from the Module. +`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, the `nnx.capture` function allows you to retrieve the sowed state. See the *Extracting Intermediate Values* guide for more details. 
```{code-cell} ipython3 class Block(nnx.Module): @@ -456,21 +456,18 @@ class MLP(nnx.Module): model = MLP(num_layers=3, dim=20, rngs=nnx.Rngs(0)) x = jnp.ones((10, 20)) -y = model(x) -intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values +_, intermediates = nnx.capture(model, nnx.Intermediate)(x) print(intermediates) ``` ### perturb -`perturb` is similar to `sow` but it aims to capture the gradient of a value, currently this is a two step process although it might be simplified in the future: +`perturb` is similar to `sow` but it aims to capture the gradient of a value. This is a two-step process: 1. Initialize the pertubation state by running the model once. 2. Pass the perturbation state as a differentiable target to `grad`. As an example lets create a simple model and use `perturb` to get the intermediate gradient `xgrad` for the variable `x`, and initialize the perturbations: ```{code-cell} ipython3 -import optax - class Model(nnx.Module): def __init__(self, rngs): self.linear1 = nnx.Linear(2, 3, rngs=rngs) @@ -483,31 +480,22 @@ class Model(nnx.Module): rngs = nnx.Rngs(0) model = Model(rngs) -optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-1), wrt=nnx.Param) x, y = rngs.uniform((1, 2)), rngs.uniform((1, 4)) -_ = model(x) # initialize perturbations -print(f"{nnx.state(model, nnx.Perturbation) = !s}") +_, perturbations = nnx.capture(model, nnx.Perturbation)(x) # initialize perturbations +print(perturbations) ``` -Next we'll create a training step function that differentiates w.r.t. both the parameters of the model and the perturbations, the later will be the gradients for the intermediate values. `nnx.jit` and `nnx.value_and_grad` will be use to automatically propagate state updates. We'll return the `loss` function and the itermediate gradients. +Next we'll differentiate a loss w.r.t. both the parameters of the model and the perturbations: the latter will be the gradients for the intermediate values. 
```{code-cell} ipython3 -@nnx.jit -def train_step(model, optimizer, x, y): - graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation) - - def loss_fn(params, perturbations): - model = nnx.merge(graphdef, params, perturbations) - return jnp.mean((model(x) - y) ** 2) - - loss, (grads, iterm_grads) = nnx.value_and_grad(loss_fn, argnums=(0, 1))(params, perturbations) - optimizer.update(model, grads) - - return loss, iterm_grads - -for step in range(2): - loss, iterm_grads = train_step(model, optimizer, x, y) - print(f"{step = }, {loss = }, {iterm_grads = !s}") +@nnx.value_and_grad(argnums=(0, 1)) +def loss_grad(model, perturbations, x, y): + def loss_fn(model): + return jnp.mean((model(x) - y) ** 2) + return nnx.capture(loss_fn, init=perturbations)(model) + +loss, (weight_grads, interm_grads) = loss_grad(model, perturbations, x, y) +interm_grads ``` ## Object