diff --git a/docs_nnx/nnx_basics_tree.ipynb b/docs_nnx/nnx_basics_tree.ipynb new file mode 100644 index 000000000..ee63ff443 --- /dev/null +++ b/docs_nnx/nnx_basics_tree.ipynb @@ -0,0 +1,675 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NNX Basics\n", + "\n", + "NNX is a Neural Networks library for JAX. NNX provides the tools to structure modeling code as [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) so it can work with transforms, `jax.tree.*` utilities, and all standard JAX APIs. This guide covers the core concepts you need to get started." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "nnx.graphlib.set_graph_mode(False)\n", + "nnx.graphlib.set_graph_updates(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "NNX's main building blocks are:\n", + "\n", + "- **`nnx.Pytree`**: Base class for pytree-compatible objects. Defines the tree structure of your model.\n", + "- **`nnx.Variable`**: Wraps array data and tracks mutable state. Subclasses like `nnx.Param` categorize different kinds of state.\n", + "- **State APIs** (`nnx.{state, split, merge, update}`): Extract, partition, reconstruct, and apply state updates.\n", + "- **NNX Transforms** (`nnx.{jit, grad, scan, ...}`): Thin wrappers over JAX transforms that automate state propagation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pytrees and Variables\n", + "\n", + "`nnx.Pytree` and `nnx.Variable` are two orthogonal systems. **Pytrees** define the structure of your model as a JAX-compatible tree. **Variables** wrap array data and enable expressing state updates via in-place mutation. 
\n", + "\n", + "`Pytree`s are Python objects that define their tree structure dynamically through their attributes. These attributes fall into two categories: **Static attributes** (e.g. `int`, `str`) are embedded in the tree structure definition and are not traced by JAX. **Data attributes** (e.g. `nnx.Variable`, `jax.Array`) are the leaves of the tree and are traced by JAX. For more details see the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html).\n", + "\n", + "Here's a typical layer definition:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class Count(nnx.Variable): pass # custom Variable types\n", + "\n", + "class Linear(nnx.Pytree):\n", + " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din, self.dout = din, dout # static attributes\n", + " self.w = nnx.Param(rngs.uniform((din, dout))) # data attribute\n", + " self.count = Count(jnp.array(0)) # data attribute\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " self.count[...] += 1 # inplace state updates\n", + " return x @ self.w # Variable are Array-like\n", + "\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "nnx.display(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Note:** Most user code uses `nnx.Module`, which is a subclass of `nnx.Pytree` with additional features such as sopport for metric reporting.\n", + "\n", + "As we can see above, Variables are array-like; they support arithmetic operators, indexing, and can be used directly in JAX expressions. You can update their value in-place using `variable[...] = new_value`. Since NNX Pytrees are standard JAX pytrees, you can use `jax.tree.*` functions directly on them:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5), model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "\n", + "model.w sum: 4.1854\n", + "doubled.w sum: 8.3709\n", + "\n", + "Pytree leaves:\n", + ".count.value: Array(1, dtype=int32, weak_type=True)\n", + ".w.value: Array([[0.8423141 , 0.18237865, 0.2271781 , 0.12072563, 0.19181347],\n", + " [0.722015 , 0.7654456 , 0.15254045, 0.9517063 , 0.02931046]], dtype=float32)\n" + ] + } + ], + "source": [ + "x = jnp.ones((3, 2))\n", + "y = model(x)\n", + "print(f'{y.shape = }, {model.count[...] 
= }')\n", + "\n", + "# jax.tree.map works directly on NNX Pytrees\n", + "doubled_model = jax.tree.map(lambda x: x * 2, model)\n", + "print(f'\\nmodel.w sum: {model.w.sum():.4f}')\n", + "print(f'doubled.w sum: {doubled_model.w.sum():.4f}')\n", + "\n", + "# jax.tree.leaves_with_path shows the full tree structure\n", + "print('\\nPytree leaves:')\n", + "for path, value in jax.tree.leaves_with_path(model):\n", + " print(f'{jax.tree_util.keystr(path)}: {value!r}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here `jax.tree.map` was first used to create a new model with each leaf Array doubled, and then `jax.tree.leaves_with_path` was used to show how JAX sees the tree structure. Notice that because Variables are also JAX pytrees containing a single element (their inner value), we see `value` as part of the leaf path." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rngs\n", + "`nnx.Rngs` simplifies managing [JAX PRNG state](https://jax.readthedocs.io/en/latest/random-numbers.html). It is itself an `nnx.Pytree` that stores a seed `key` and an incrementing `counter` in `Variable`s internally. 
By calling it, `Rngs` can produce new PRNG keys:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "key1 = Array((), dtype=key) overlaying:\n", + "[1797259609 2579123966]\n", + "key2 = Array((), dtype=key) overlaying:\n", + "[ 928981903 3453687069]\n", + "arr = Array([[ 1.2956359 , 1.3550105 , -0.40960556],\n", + " [-0.77188545, 0.38094172, 0.01888919]], dtype=float32)\n", + "\u001b[38;2;79;201;177mRngs\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # RngState: 2 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mdefault\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngStream\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # RngState: 2 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m,\n", + " \u001b[38;2;156;220;254mkey\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngKey\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (8 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray((), dtype=key) overlaying:\n", + " [0 0],\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mcount\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mRngCount\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(3, dtype=uint32),\n", + " \u001b[38;2;156;220;254mtag\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;207;144;120m'default'\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ 
+ "rngs = nnx.Rngs(0) # seeded with 0\n", + "\n", + "key1 = rngs() # get a raw key\n", + "key2 = rngs() # different key (counter incremented)\n", + "arr = rngs.normal((2, 3)) # draw samples directly\n", + "\n", + "print(f'{key1 = }')\n", + "print(f'{key2 = }')\n", + "print(f'{arr = }')\n", + "print(rngs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we've seen so far, `Rngs` conveniently exposes every `jax.random.*` distribution as a method (e.g. `rngs.uniform(...)`, `rngs.normal(...)`) that requires no `key` argument and returns different random values on every call, which greatly simplifies the user experience. In general `Rngs` can hold multiple keys and counters in structures called `RngStream`s; above we see that the `default` stream is being used. For more information check out the [Randomness guide](https://flax.readthedocs.io/en/latest/guides/randomness.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Nested Modules\n", + "\n", + "Pytree subclasses compose naturally: you can assign one as an attribute of another to build nested models. The example below builds a simple `MLP` from two `Linear` layers:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 5)\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class MLP(nnx.Pytree):\n", + " def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.din, self.dmid, self.dout = din, dmid, dout # static attributes\n", + " self.linear1 = Linear(din, dmid, rngs=rngs) # data attribute\n", + " self.linear2 = Linear(dmid, dout, rngs=rngs) # data attribute\n", + "\n", + " def __call__(self, x: jax.Array):\n", + " x = nnx.relu(self.linear1(x))\n", + " return self.linear2(x)\n", + "\n", + "mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0))\n", + "y = mlp(jnp.ones((3, 2)))\n", + "print(f'{y.shape = }')\n", + "\n", + "nnx.display(mlp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because the entire model is a single pytree, all the `jax.tree.*` functions, JAX transforms, and NNX state APIs work on the full nested structure at once. For more info check out the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## JAX Transforms\n", + "\n", + "NNX models can be passed directly to JAX transforms like `jax.jit`. 
However, JAX transforms create pure functions, meaning that they won't propagate side effects such as Variable state updates back to the caller:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit\n", + "def forward(model, x): # pure function\n", + " y = model(x)\n", + " return y\n", + "\n", + "y = forward(model, x)\n", + "\n", + "print(model.count[...]) # no state update" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here `count` was not updated because inside `jax.jit` new Variable copies are created so any updates inside will not be reflected outside. To propagate updates we can use two NNX helpers. `nnx.state(obj, *filters)` extracts the current state of all Variables in `obj` as a nested `State` dict; you can pass **filters** to select specific Variable types, for example `nnx.state(model, Count)` extracts only `Count` Variables (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for details). `nnx.update(obj, state)` writes a `State` back into the corresponding Variables of `obj`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + } + ], + "source": [ + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "\n", + "@jax.jit\n", + "def forward(model, x):\n", + " y = model(x)\n", + " return y, nnx.state(model, Count) # propagate state\n", + "\n", + "y, updates = forward(model, x)\n", + "nnx.update(model, updates) # apply state updates\n", + "\n", + "print(model.count[...]) # updated successfully" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we could've also chosen to return the entire `model` and replace its reference outside; however, using `nnx.state`/`nnx.update` is preferred as NNX promotes preserving existing Variable references." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training step with JAX transforms\n", + "\n", + "For a full training step we also need to differentiate with respect to some parameters while keeping the rest non-differentiable. `nnx.split` and `nnx.merge` let us partition and reconstruct the model. `nnx.split(obj, *filters)` returns a structure definition (`GraphDef`) followed by one `State` group per filter, where the catch-all filter `...` matches everything not yet matched by a previous filter (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for the full filter language). `nnx.merge(graphdef, *states)` reconstructs a copy of the object from its definition and state groups. We will use these to select the differentiable parameters when passing them to `jax.grad`.\n", + "\n", + "The example below shows a complete training step using raw JAX transforms. 
`nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and stores its internal state as Variables, providing a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@jax.jit\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params, nondiff):\n", + " nondiff = nnx.clone(nondiff) # refresh trace state\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss, nnx.state(model, Count) # propagate state\n", + "\n", + " grads, updates = jax.grad(loss_fn, has_aux=True)(params, nondiff)\n", + " nnx.update(model, updates)\n", + " optimizer.update(model, grads)\n", + "\n", + " return nnx.state((model, optimizer))\n", + "\n", + "updates = train_step(model, optimizer, x, y)\n", + "nnx.update((model, optimizer), updates)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few things to note. The state of the `model` and `optimizer` is extracted at once by packing them in a tuple (or any pytree), and `nnx.update` accepts the same structure. By default `jax.grad` differentiates with respect to the first positional argument only, `params` in our case. 
Finally, `nnx.clone` is needed because `jax.grad` passes non-differentiable inputs (here `nondiff`) directly without tracing them, so we must manually clone them to refresh the trace state of their Variables, preventing tracer leakage. Omitting `nnx.clone` raises an error." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## NNX Transforms\n", + "\n", + "NNX transforms (`nnx.jit`, `nnx.grad`, ...) are thin wrappers over JAX transforms that provide the exact same APIs. Their main feature is **automatic state propagation**: the state of all input Variables is tracked and updated automatically behind the scenes. This removes the need for the `nnx.state/update` boilerplate and the use of `nnx.clone`:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "@nnx.jit # automatic state propagation\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params, nondiff):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss\n", + "\n", + " grads = nnx.grad(loss_fn)(params, nondiff)\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] 
= }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that `train_step` doesn't need to return anything as `nnx.jit` propagates all Variable updates (model parameters, optimizer state, counts) automatically." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Graph Mode\n", + "\n", + "Certain programs are easier to express by sharing references between objects in different parts of a structure; however, this is not compatible with JAX's pytree model. If we create a simple model that shares a reference to the same Variable in two different attributes, NNX transforms and most other APIs will raise an error as sharing can result in inconsistencies:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error: Variable at [0][0].b was already seen at [0][0].a. tree-mode jit does not support shared Variable references.\n" + ] + } + ], + "source": [ + "@nnx.dataclass\n", + "class Foo(nnx.Module):\n", + " a: nnx.Param\n", + " b: nnx.Param\n", + "\n", + "p = nnx.Param(jnp.array(1.0))\n", + "model = Foo(p, p) # shared Param\n", + "\n", + "@nnx.jit\n", + "def forward(model, x):\n", + " model.a[...] += 1.0\n", + " return model.a * x + model.b\n", + "\n", + "try:\n", + " forward(model, jnp.array(1.0))\n", + "except ValueError as e:\n", + " print(f'Error: {e}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, at the cost of some Python overhead, `graph=True` can be passed to NNX APIs to enable **graph mode**. In graph mode, general graph structures are allowed as long as their Variables are transformed consistently. We can fix the above example by enabling graph mode in `nnx.jit`:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y = 6.0, model.a[...] = 3.0, model.b[...] 
= 3.0\n" + ] + } + ], + "source": [ + "@nnx.jit(graph=True)\n", + "def forward(model, x):\n", + " model.a[...] += 1.0\n", + " return model.a * x + model.b\n", + "\n", + "y = forward(model, jnp.array(1.0))\n", + "\n", + "print(f'{y = !s}, {model.a[...] = !s}, {model.b[...] = !s}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hijax (experimental)\n", + "\n", + "JAX's experimental **Hijax** API allows custom mutable types whose state updates propagate automatically through JAX transforms. When enabled via `nnx.var_defaults(hijax=True)`, plain JAX transforms like `jax.jit` handle state propagation of `Variable`s without any manual `nnx.state` / `nnx.update` calls. As a bonus, in hijax mode Variables can also be passed as captures, further simplifying the loss function:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Count: 1 (4 B), Param: 10 (40 B), Total: 11 (44 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mdin\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m,\n", + " \u001b[38;2;156;220;254mdout\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m5\u001b[0m,\n", + " \u001b[38;2;156;220;254mw\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 10 (40 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m5\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, 
\u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mcount\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mCount\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mhijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "model.count[...] = Array(1, dtype=int32, weak_type=True)\n", + "optimizer.step[...] = Array(1, dtype=uint32)\n" + ] + } + ], + "source": [ + "with nnx.var_defaults(hijax=True): # enables Hijax Variables\n", + " x, y = jnp.ones((3, 2)), jnp.ones((3, 5))\n", + " model = Linear(2, 5, rngs=nnx.Rngs(0))\n", + " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + "print(model) # display Hijax Variables\n", + "\n", + "@jax.jit # automatic state propagation\n", + "def train_step(model, optimizer, x, y):\n", + " # use same filter as Optimizer's `wrt`\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + "\n", + " def loss_fn(params):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " loss = jnp.mean((model(x) - y) ** 2)\n", + " return loss\n", + "\n", + " grads = jax.grad(loss_fn)(nnx.vars_as(params, hijax=False)) # disable hijax for param grads\n", + " optimizer.update(model, grads)\n", + "\n", + "train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{model.count[...] = }')\n", + "print(f'{optimizer.step[...] 
= }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a temporary limitation, `jax.grad` does not yet handle mutable Hijax types. We work around this by converting `params` to regular Variables via `nnx.vars_as(params, hijax=False)` before passing them to `grad`. Hijax can also be enabled on a per-Variable basis by passing `hijax=True` to the constructor:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "v[...] = 1\n", + "v[...] = 2\n" + ] + } + ], + "source": [ + "v = nnx.Variable(jnp.array(1), hijax=True)\n", + "\n", + "@jax.jit\n", + "def inc(v):\n", + " v[...] += 1\n", + "\n", + "print(f'{v[...] = !s}')\n", + "inc(v)\n", + "print(f'{v[...] = !s}')" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs_nnx/nnx_basics_tree.md b/docs_nnx/nnx_basics_tree.md new file mode 100644 index 000000000..2bdc6f577 --- /dev/null +++ b/docs_nnx/nnx_basics_tree.md @@ -0,0 +1,319 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.13.8 +--- + +# NNX Basics + +NNX is a Neural Networks library for JAX. NNX provides the tools to structure modeling code as [JAX pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) so it can work with transforms, `jax.tree.*` utilities, and all standard JAX APIs. This guide covers the core concepts you need to get started. 
+ +```{code-cell} ipython3 +from flax import nnx +import jax +import jax.numpy as jnp + +nnx.graphlib.set_graph_mode(False) +nnx.graphlib.set_graph_updates(False) +``` + +NNX's main building blocks are: + +- **`nnx.Pytree`**: Base class for pytree-compatible objects. Defines the tree structure of your model. +- **`nnx.Variable`**: Wraps array data and tracks mutable state. Subclasses like `nnx.Param` categorize different kinds of state. +- **State APIs** (`nnx.{state, split, merge, update}`): Extract, partition, reconstruct, and apply state updates. +- **NNX Transforms** (`nnx.{jit, grad, scan, ...}`): Thin wrappers over JAX transforms that automate state propagation. + ++++ + +## Pytrees and Variables + +`nnx.Pytree` and `nnx.Variable` are two orthogonal systems. **Pytrees** define the structure of your model as a JAX-compatible tree. **Variables** wrap array data and enable expressing state updates via in-place mutation. + +`Pytree`s are Python objects that define their tree structure dynamically through their attributes. These attributes fall into two categories: **Static attributes** (e.g. `int`, `str`) are embedded in the tree structure definition and are not traced by JAX. **Data attributes** (e.g. `nnx.Variable`, `jax.Array`) are the leaves of the tree and are traced by JAX. For more details see the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html). + +Here's a typical layer definition: + +```{code-cell} ipython3 +class Count(nnx.Variable): pass # custom Variable types + +class Linear(nnx.Pytree): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.din, self.dout = din, dout # static attributes + self.w = nnx.Param(rngs.uniform((din, dout))) # data attribute + self.count = Count(jnp.array(0)) # data attribute + + def __call__(self, x: jax.Array): + self.count[...] 
+= 1 # inplace state updates + return x @ self.w # Variables are Array-like + +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +nnx.display(model) +``` + +> **Note:** Most user code uses `nnx.Module`, which is a subclass of `nnx.Pytree` with additional features such as support for metric reporting. + +As we can see above, Variables are array-like; they support arithmetic operators, indexing, and can be used directly in JAX expressions. You can update their value in-place using `variable[...] = new_value`. Since NNX Pytrees are standard JAX pytrees, you can use `jax.tree.*` functions directly on them: + +```{code-cell} ipython3 +x = jnp.ones((3, 2)) +y = model(x) +print(f'{y.shape = }, {model.count[...] = }') + +# jax.tree.map works directly on NNX Pytrees +doubled_model = jax.tree.map(lambda x: x * 2, model) +print(f'\nmodel.w sum: {model.w.sum():.4f}') +print(f'doubled.w sum: {doubled_model.w.sum():.4f}') + +# jax.tree.leaves_with_path shows the full tree structure +print('\nPytree leaves:') +for path, value in jax.tree.leaves_with_path(model): + print(f'{jax.tree_util.keystr(path)}: {value!r}') +``` + +Here `jax.tree.map` was first used to create a new model with each leaf Array doubled, and then `jax.tree.leaves_with_path` was used to show how JAX sees the tree structure. Notice that because Variables are also JAX pytrees containing a single element (their inner value), we see `value` as part of the leaf path. + ++++ + +## Rngs +`nnx.Rngs` simplifies managing [JAX PRNG state](https://jax.readthedocs.io/en/latest/random-numbers.html). It is itself an `nnx.Pytree` that stores a seed `key` and an incrementing `counter` in `Variable`s internally. 
By calling it, `Rngs` can produce new PRNG keys: + +```{code-cell} ipython3 +rngs = nnx.Rngs(0) # seeded with 0 + +key1 = rngs() # get a raw key +key2 = rngs() # different key (counter incremented) +arr = rngs.normal((2, 3)) # draw samples directly + +print(f'{key1 = }') +print(f'{key2 = }') +print(f'{arr = }') +print(rngs) +``` + +As we've seen so far, `Rngs` conveniently exposes every `jax.random.*` distribution as a method (e.g. `rngs.uniform(...)`, `rngs.normal(...)`) that requires no `key` argument and returns different random values on every call, which greatly simplifies the user experience. In general `Rngs` can hold multiple keys and counters in structures called `RngStream`s; above we see that the `default` stream is being used. For more information check out the [Randomness guide](https://flax.readthedocs.io/en/latest/guides/randomness.html). + ++++ + +## Nested Modules + +Pytree subclasses compose naturally: you can assign one as an attribute of another to build nested models. The example below builds a simple `MLP` from two `Linear` layers: + +```{code-cell} ipython3 +class MLP(nnx.Pytree): + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.din, self.dmid, self.dout = din, dmid, dout # static attributes + self.linear1 = Linear(din, dmid, rngs=rngs) # data attribute + self.linear2 = Linear(dmid, dout, rngs=rngs) # data attribute + + def __call__(self, x: jax.Array): + x = nnx.relu(self.linear1(x)) + return self.linear2(x) + +mlp = MLP(2, 16, 5, rngs=nnx.Rngs(0)) +y = mlp(jnp.ones((3, 2))) +print(f'{y.shape = }') + +nnx.display(mlp) +``` + +Because the entire model is a single pytree, all the `jax.tree.*` functions, JAX transforms, and NNX state APIs work on the full nested structure at once. For more info check out the [Pytree guide](https://flax.readthedocs.io/en/latest/guides/pytree.html). + ++++ + +## JAX Transforms + +NNX models can be passed directly to JAX transforms like `jax.jit`. 
However, JAX transforms create pure functions, meaning that they won't propagate side effects such as Variable state updates back to the caller: + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +@jax.jit +def forward(model, x): # pure function + y = model(x) + return y + +y = forward(model, x) + +print(model.count[...]) # no state update +``` + +Here `count` was not updated because inside `jax.jit` new Variable copies are created, so any updates made inside are not reflected outside. To propagate updates we can use two NNX helpers. `nnx.state(obj, *filters)` extracts the current state of all Variables in `obj` as a nested `State` dict; you can pass **filters** to select specific Variable types, for example `nnx.state(model, Count)` extracts only `Count` Variables (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for details). `nnx.update(obj, state)` writes a `State` back into the corresponding Variables of `obj`. + +```{code-cell} ipython3 +model = Linear(2, 5, rngs=nnx.Rngs(0)) + +@jax.jit +def forward(model, x): + y = model(x) + return y, nnx.state(model, Count) # propagate state + +y, updates = forward(model, x) +nnx.update(model, updates) # apply state updates + +print(model.count[...]) # updated successfully +``` + +In this example we could've also chosen to return the entire `model` and replace its reference outside; however, using `nnx.state`/`nnx.update` is preferred as NNX promotes preserving existing Variable references. + ++++ + +### Training step with JAX transforms + +For a full training step we also need to differentiate with respect to some parameters while keeping the rest non-differentiable. `nnx.split` and `nnx.merge` let us partition and reconstruct the model. 
`nnx.split(obj, *filters)` returns a structure definition (`GraphDef`) followed by one `State` group per filter, where the catch-all filter `...` matches everything not yet matched by a previous filter (see the [Filters guide](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) for the full filter language). `nnx.merge(graphdef, *states)` reconstructs a copy of the object from its definition and state groups. We will use these to select the differentiable parameters when passing them to `jax.grad`. + +The example below shows a complete training step using raw JAX transforms. `nnx.Optimizer` wraps an [Optax](https://optax.readthedocs.io/) optimizer and stores its internal state as Variables, providing a simple `update(model, grads)` method that performs in-place updates to both the optimizer state and model parameters: + +```{code-cell} ipython3 +import optax + +x, y = jnp.ones((3, 2)), jnp.ones((3, 5)) +model = Linear(2, 5, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + +@jax.jit +def train_step(model, optimizer, x, y): + # use same filter as Optimizer's `wrt` + graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) + + def loss_fn(params, nondiff): + nondiff = nnx.clone(nondiff) # refresh trace state + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss, nnx.state(model, Count) # propagate state + + grads, updates = jax.grad(loss_fn, has_aux=True)(params, nondiff) + nnx.update(model, updates) + optimizer.update(model, grads) + + return nnx.state((model, optimizer)) + +updates = train_step(model, optimizer, x, y) +nnx.update((model, optimizer), updates) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +A few things to note. The state of the `model` and `optimizer` is extracted at once by packing them in a tuple (or any pytree), and `nnx.update` accepts the same structure. 
By default `jax.grad` differentiates with respect to the first positional argument only, `params` in our case. Finally, `nnx.clone` is needed because `jax.grad` passes non-differentiable inputs (here `nondiff`) through directly without tracing them, so we must manually clone them to refresh the trace state of their Variables, preventing tracer leakage. Omitting `nnx.clone` raises an error.
+
++++
+
+## NNX Transforms
+
+NNX transforms (`nnx.jit`, `nnx.grad`, ...) are thin wrappers over JAX transforms that provide the exact same APIs. Their main feature is **automatic state propagation**: the state of all input Variables is tracked and updated automatically behind the scenes. This removes the need for the `nnx.state`/`nnx.update` boilerplate and the use of `nnx.clone`:
+
+```{code-cell} ipython3
+x, y = jnp.ones((3, 2)), jnp.ones((3, 5))
+model = Linear(2, 5, rngs=nnx.Rngs(0))
+optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
+
+@nnx.jit  # automatic state propagation
+def train_step(model, optimizer, x, y):
+  # use same filter as Optimizer's `wrt`
+  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
+
+  def loss_fn(params, nondiff):
+    model = nnx.merge(graphdef, params, nondiff)
+    loss = jnp.mean((model(x) - y) ** 2)
+    return loss
+
+  grads = nnx.grad(loss_fn)(params, nondiff)
+  optimizer.update(model, grads)
+
+train_step(model, optimizer, x, y)
+
+print(f'{model.count[...] = }')
+print(f'{optimizer.step[...] = }')
+```
+
+Notice that `train_step` doesn't need to return anything, as `nnx.jit` propagates all Variable updates (model parameters, optimizer state, counts) automatically.
+
++++
+
+## Graph Mode
+
+Certain programs are easier to express by sharing references to the same objects across different parts of a structure; however, this is not compatible with JAX's pytree model.
If we create a simple model that shares a reference to the same Variable in two different attributes, NNX transforms and most other APIs will raise an error, as sharing can result in inconsistencies:
+
+```{code-cell} ipython3
+@nnx.dataclass
+class Foo(nnx.Module):
+  a: nnx.Param
+  b: nnx.Param
+
+p = nnx.Param(jnp.array(1.0))
+model = Foo(p, p)  # shared Param
+
+@nnx.jit
+def forward(model, x):
+  model.a[...] += 1.0
+  return model.a * x + model.b
+
+try:
+  forward(model, jnp.array(1.0))
+except ValueError as e:
+  print(f'Error: {e}')
+```
+
+However, at the cost of some Python overhead, `graph=True` can be passed to NNX APIs to enable **graph mode**. In graph mode, general graph structures are allowed as long as their Variables are transformed consistently. We can fix the above example by enabling graph mode in `nnx.jit`:
+
+```{code-cell} ipython3
+@nnx.jit(graph=True)
+def forward(model, x):
+  model.a[...] += 1.0
+  return model.a * x + model.b
+
+y = forward(model, jnp.array(1.0))
+
+print(f'{y = !s}, {model.a[...] = !s}, {model.b[...] = !s}')
+```
+
+## Hijax (experimental)
+
+JAX's experimental **Hijax** API allows custom mutable types whose state updates propagate automatically through JAX transforms. When enabled via `nnx.var_defaults(hijax=True)`, plain JAX transforms like `jax.jit` handle state propagation of `Variable`s without any manual `nnx.state` / `nnx.update` calls. As a bonus, in hijax mode Variables can also be passed as captures, further simplifying the loss function:
+
+```{code-cell} ipython3
+with nnx.var_defaults(hijax=True):  # enables Hijax Variables
+  x, y = jnp.ones((3, 2)), jnp.ones((3, 5))
+  model = Linear(2, 5, rngs=nnx.Rngs(0))
+  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
+
+print(model)  # display Hijax Variables
+
+@jax.jit  # automatic state propagation
+def train_step(model, optimizer, x, y):
+  # use same filter as Optimizer's `wrt`
+  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
+ + def loss_fn(params): + model = nnx.merge(graphdef, params, nondiff) + loss = jnp.mean((model(x) - y) ** 2) + return loss + + grads = jax.grad(loss_fn)(nnx.vars_as(params, hijax=False)) # disable hijax for param grads + optimizer.update(model, grads) + +train_step(model, optimizer, x, y) + +print(f'{model.count[...] = }') +print(f'{optimizer.step[...] = }') +``` + +As a temporary limitation, `jax.grad` does not yet handle mutable Hijax types. We work around this by converting `params` to regular Variables via `nnx.vars_as(params, hijax=False)` before passing them to `grad`. Hijax can also be enabled on a per-Variable basis by passing `hijax=True` to the constructor: + +```{code-cell} ipython3 +v = nnx.Variable(jnp.array(1), hijax=True) + +@jax.jit +def inc(v): + v[...] += 1 + +print(f'{v[...] = !s}') +inc(v) +print(f'{v[...] = !s}') +``` diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 4004b02c0..37e2247af 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -25,7 +25,7 @@ from flax.nnx.pytreelib import Pytree from flax.nnx.variablelib import Variable -M = tp.TypeVar('M', bound=nnx.Module) +M = tp.TypeVar('M') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) class OptState(Variable):