Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 22 additions & 51 deletions docs_nnx/guides/pytree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 19,
"id": "9b2b929d",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -689,7 +689,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 26,
"id": "668db479",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -957,12 +957,12 @@
"metadata": {},
"source": [
"### sow\n",
"`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, NNX APIs such as `nnx.state` or `nnx.pop` are a good way of retrieving the sowed state, however `pop` is recommended because it explicitly removes the temporary state from the Module."
"`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, the `nnx.capture` function allows you to retrieve the sowed state. See the *Extracting Intermediate Values* guide for more details."
]
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 4,
"id": "ca9f58a2",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -1016,8 +1016,7 @@
"\n",
"model = MLP(num_layers=3, dim=20, rngs=nnx.Rngs(0))\n",
"x = jnp.ones((10, 20))\n",
"y = model(x)\n",
"intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values\n",
"_, intermediates = nnx.capture(model, nnx.Intermediate)(x)\n",
"print(intermediates)"
]
},
Expand All @@ -1027,7 +1026,7 @@
"metadata": {},
"source": [
"### perturb\n",
"`perturb` is similar to `sow` but it aims to capture the gradient of a value, currently this is a two step process although it might be simplified in the future:\n",
"`perturb` is similar to `sow` but it aims to capture the gradient of a value. This is a two-step process:\n",
"1. Initialize the perturbation state by running the model once.\n",
"2. Pass the perturbation state as a differentiable target to `grad`.\n",
"\n",
Expand All @@ -1036,15 +1035,15 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 12,
"id": "41398e14",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nnx.state(model, nnx.Perturbation) = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
"\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
" \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n",
" \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[0., 0., 0.]], dtype=float32)\n",
" \u001b[38;2;255;213;3m)\u001b[0m\n",
Expand All @@ -1053,8 +1052,6 @@
}
],
"source": [
"import optax\n",
"\n",
"class Model(nnx.Module):\n",
" def __init__(self, rngs):\n",
" self.linear1 = nnx.Linear(2, 3, rngs=rngs)\n",
Expand All @@ -1067,60 +1064,34 @@
"\n",
"rngs = nnx.Rngs(0)\n",
"model = Model(rngs)\n",
"optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-1), wrt=nnx.Param)\n",
"x, y = rngs.uniform((1, 2)), rngs.uniform((1, 4))\n",
"_ = model(x) # initialize perturbations\n",
"print(f\"{nnx.state(model, nnx.Perturbation) = !s}\")"
"_, perturbations = nnx.capture(model, nnx.Perturbation)(x) # initialize perturbations\n",
"print(perturbations)"
]
},
{
"cell_type": "markdown",
"id": "c9221005",
"metadata": {},
"source": [
"Next we'll create a training step function that differentiates w.r.t. both the parameters of the model and the perturbations, the later will be the gradients for the intermediate values. `nnx.jit` and `nnx.value_and_grad` will be use to automatically propagate state updates. We'll return the `loss` function and the itermediate gradients."
"Next we'll differentiate a loss w.r.t. both the parameters of the model and the perturbations: the latter will be the gradients for the intermediate values."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "d10effba",
"execution_count": 15,
"id": "56cc9fde-ce49-49ef-84ab-01fa5c91914b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step = 0, loss = Array(0.7326511, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
" \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n",
" \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.430146 , -0.14356601, 0.2935633 ]], dtype=float32)\n",
" \u001b[38;2;255;213;3m)\u001b[0m\n",
"\u001b[38;2;255;213;3m})\u001b[0m\n",
"step = 1, loss = Array(0.65039134, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
" \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n",
" \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.38535568, -0.11745065, 0.24441527]], dtype=float32)\n",
" \u001b[38;2;255;213;3m)\u001b[0m\n",
"\u001b[38;2;255;213;3m})\u001b[0m\n"
]
}
],
"outputs": [],
"source": [
"@nnx.jit\n",
"def train_step(model, optimizer, x, y):\n",
" graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)\n",
"\n",
" def loss_fn(params, perturbations):\n",
" model = nnx.merge(graphdef, params, perturbations)\n",
" return jnp.mean((model(x) - y) ** 2)\n",
"\n",
" loss, (grads, iterm_grads) = nnx.value_and_grad(loss_fn, argnums=(0, 1))(params, perturbations)\n",
" optimizer.update(model, grads)\n",
"\n",
" return loss, iterm_grads\n",
"@nnx.value_and_grad(argnums=(0, 1))\n",
"def loss_grad(model, perturbations, x, y):\n",
" def loss_fn(model):\n",
" return jnp.mean((model(x) - y) ** 2)\n",
" return nnx.capture(loss_fn, init=perturbations)(model)\n",
"\n",
"for step in range(2):\n",
" loss, iterm_grads = train_step(model, optimizer, x, y)\n",
" print(f\"{step = }, {loss = }, {iterm_grads = !s}\")"
"loss, (weight_grads, interm_grads) = loss_grad(model, perturbations, x, y)\n",
"interm_grads"
]
},
{
Expand Down Expand Up @@ -1194,7 +1165,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.13.5"
}
},
"nbformat": 4,
Expand Down
40 changes: 14 additions & 26 deletions docs_nnx/guides/pytree.md
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ NNX Modules are `Pytree`s that have two additional methods for traking intermedi
+++

### sow
`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, NNX APIs such as `nnx.state` or `nnx.pop` are a good way of retrieving the sowed state, however `pop` is recommended because it explicitly removes the temporary state from the Module.
`sow` receives a `Variable` type, a `name`, and a `value`, and stores it in the `Module` so it can be retrieved at a later time. As the following example shows, the `nnx.capture` function allows you to retrieve the sowed state. See the *Extracting Intermediate Values* guide for more details.

```{code-cell} ipython3
class Block(nnx.Module):
Expand All @@ -456,21 +456,18 @@ class MLP(nnx.Module):

model = MLP(num_layers=3, dim=20, rngs=nnx.Rngs(0))
x = jnp.ones((10, 20))
y = model(x)
intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values
_, intermediates = nnx.capture(model, nnx.Intermediate)(x)
print(intermediates)
```

### perturb
`perturb` is similar to `sow` but it aims to capture the gradient of a value, currently this is a two step process although it might be simplified in the future:
`perturb` is similar to `sow` but it aims to capture the gradient of a value. This is a two-step process:
1. Initialize the perturbation state by running the model once.
2. Pass the perturbation state as a differentiable target to `grad`.

As an example let's create a simple model and use `perturb` to get the intermediate gradient `xgrad` for the variable `x`, and initialize the perturbations:

```{code-cell} ipython3
import optax

class Model(nnx.Module):
def __init__(self, rngs):
self.linear1 = nnx.Linear(2, 3, rngs=rngs)
Expand All @@ -483,31 +480,22 @@ class Model(nnx.Module):

rngs = nnx.Rngs(0)
model = Model(rngs)
optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-1), wrt=nnx.Param)
x, y = rngs.uniform((1, 2)), rngs.uniform((1, 4))
_ = model(x) # initialize perturbations
print(f"{nnx.state(model, nnx.Perturbation) = !s}")
_, perturbations = nnx.capture(model, nnx.Perturbation)(x) # initialize perturbations
print(perturbations)
```

Next we'll create a training step function that differentiates w.r.t. both the parameters of the model and the perturbations, the later will be the gradients for the intermediate values. `nnx.jit` and `nnx.value_and_grad` will be use to automatically propagate state updates. We'll return the `loss` function and the itermediate gradients.
Next we'll differentiate a loss w.r.t. both the parameters of the model and the perturbations: the latter will be the gradients for the intermediate values.

```{code-cell} ipython3
@nnx.jit
def train_step(model, optimizer, x, y):
graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

def loss_fn(params, perturbations):
model = nnx.merge(graphdef, params, perturbations)
return jnp.mean((model(x) - y) ** 2)

loss, (grads, iterm_grads) = nnx.value_and_grad(loss_fn, argnums=(0, 1))(params, perturbations)
optimizer.update(model, grads)

return loss, iterm_grads

for step in range(2):
loss, iterm_grads = train_step(model, optimizer, x, y)
print(f"{step = }, {loss = }, {iterm_grads = !s}")
@nnx.value_and_grad(argnums=(0, 1))
def loss_grad(model, perturbations, x, y):
def loss_fn(model):
return jnp.mean((model(x) - y) ** 2)
return nnx.capture(loss_fn, init=perturbations)(model)

loss, (weight_grads, interm_grads) = loss_grad(model, perturbations, x, y)
interm_grads
```

## Object
Expand Down
Loading