126 changes: 108 additions & 18 deletions docs_nnx/guides/view.ipynb
@@ -6,7 +6,9 @@
"metadata": {},
"source": [
"# Model Views\n",
"This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:"
"This guide covers how to use NNX Views, which are useful for handling state in layers such as `Dropout` and `BatchNorm` that behave differently in training and evaluation. Similar to `.view` for NumPy arrays, NNX Views allow you to modify static attributes of a model while still sharing the same underlying data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes.\n",
"\n",
"NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original."
]
},
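The "shared data, different static flags" idea behind views can be sketched in plain Python. This is a toy analogy only, not the NNX implementation: `ToyView` and its fields are hypothetical names invented for illustration.

```python
# Toy analogy only -- NOT the NNX implementation. Two views alias the
# same parameter storage while carrying their own static flags.
class ToyView:
    def __init__(self, params, deterministic):
        self.params = params                 # aliased storage, never copied
        self.deterministic = deterministic   # static, per-view flag

params = {"kernel": [1.0, 2.0]}
train_view = ToyView(params, deterministic=False)
eval_view = ToyView(params, deterministic=True)

assert train_view.params is eval_view.params              # same storage
assert train_view.deterministic != eval_view.deterministic

# An update made through one view is visible through the other,
# mirroring how all NNX views see the same parameter values.
train_view.params["kernel"][0] = 9.0
assert eval_view.params["kernel"][0] == 9.0
```

The real `nnx.with_modules` additionally walks the module tree and dispatches the flags to each submodule, but the aliasing behavior is the same.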
{
@@ -25,8 +27,8 @@
")\n",
"\n",
"# set train and eval modes\n",
"train_model = nnx.view(model, deterministic=False, use_running_average=False)\n",
"eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n",
"train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)\n",
"eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)\n",
"\n",
"# Can see deterministic is different between train_model and eval_model\n",
"assert train_model.layers[2].deterministic == False\n",
@@ -35,7 +37,7 @@
"# Weights are shared between the models\n",
"assert train_model.layers[0].kernel is eval_model.layers[0].kernel\n",
"\n",
"# Print information about kwargs for nnx.view with nnx.view_info\n",
"# Print information about kwargs for nnx.with_modules with nnx.view_info\n",
"print(nnx.view_info(model))"
]
},
@@ -48,7 +50,7 @@
"\n",
"Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten.\n",
"\n",
"`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below."
"`nnx.with_modules` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below."
]
},
{
@@ -115,7 +117,7 @@
"source": [
"From the model display, we can see that `Dropout` has `deterministic == False`, indicating that the model is in training mode. To learn this, we had to display the model or know that `Dropout` defaults to training mode; the code alone does not make the model's state clear. We instead want to be explicit about which state the model is in. \n",
"\n",
"This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below."
"This is where `nnx.with_modules` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below."
]
},
{
@@ -125,8 +127,8 @@
"metadata": {},
"outputs": [],
"source": [
"train_model = nnx.view(model, deterministic=False)\n",
"eval_model = nnx.view(model, deterministic=True)\n",
"train_model = nnx.with_modules(model, deterministic=False)\n",
"eval_model = nnx.with_modules(model, deterministic=True)\n",
"\n",
"# weights are references to the same data\n",
"assert train_model.lin1.kernel is eval_model.lin1.kernel\n",
@@ -141,7 +143,7 @@
"id": "5c1ee1db",
"metadata": {},
"source": [
"## Example with `nnx.view`"
"## Example with `nnx.with_modules`"
]
},
{
@@ -196,8 +198,8 @@
"source": [
"model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)\n",
"optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)\n",
"train_model = nnx.view(model, deterministic=False) # training view\n",
"eval_model = nnx.view(model, deterministic=True) # eval view\n",
"train_model = nnx.with_modules(model, deterministic=False) # training view\n",
"eval_model = nnx.with_modules(model, deterministic=True) # eval view\n",
"\n",
"eval_results = []\n",
"for epoch in range(total_epochs):\n",
@@ -216,7 +218,7 @@
"metadata": {},
"source": [
"## Getting information with `nnx.view_info`\n",
"To see more information about the options for `nnx.view`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions."
"To see more information about the options for `nnx.with_modules`, we can use the `nnx.view_info` function. It displays each submodule that defines a `set_view` method, along with the keyword arguments each one accepts, including type information, default values, and docstring descriptions."
]
},
{
@@ -234,9 +236,9 @@
"id": "47479be6",
"metadata": {},
"source": [
"## Writing modules compatible with `nnx.view`\n",
"## Writing modules compatible with `nnx.with_modules`\n",
"\n",
"You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.view` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about.\n",
"You can make any custom module work with `nnx.with_modules` by defining a `set_view` method. When `nnx.with_modules` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.with_modules` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about.\n",
"\n",
"Your `set_view` method should follow these conventions:\n",
"\n",
@@ -293,7 +295,7 @@
"\n",
"\n",
"model = PrintLayer()\n",
"model_print = nnx.view(model, msg='Hello, World!')\n",
"model_print = nnx.with_modules(model, msg='Hello, World!')\n",
"\n",
"model() # nothing printed\n",
"model_print() # prints \"Hello, World!\""
@@ -320,14 +322,76 @@
},
{
"cell_type": "markdown",
"id": "1acbcc09",
"id": "984b8eca",
"metadata": {},
"source": [
"The output shows that `PrintLayer` accepts a `msg` kwarg of type `str` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n",
"\n",
"## Using `with_vars`\n",
"\n",
"{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n",
"\n",
"The flags it controls are:\n",
"\n",
"- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n",
"- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n",
"- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n",
"\n",
"The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0938e1f6",
"metadata": {},
"outputs": [],
"source": [
"from flax import nnx\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"class SimpleModel(nnx.Module):\n",
" def __init__(self, rngs):\n",
" self.linear = nnx.Linear(2, 3, rngs=rngs)\n",
"\n",
"model = SimpleModel(nnx.Rngs(0))\n",
"\n",
"# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n",
"ref_model = nnx.with_vars(model, ref=True)\n",
"ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n",
"\n",
"# The original model's kernel is unchanged; ref_model has doubled values\n",
"assert model.linear.kernel is not ref_model.linear.kernel"
]
},
{
"cell_type": "markdown",
"id": "67c71d62",
"metadata": {},
"source": [
"Use the `only` filter to convert only a subset of Variables:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "254fe344",
"metadata": {},
"outputs": [],
"source": [
"# only convert Param variables, leave BatchStat variables unchanged\n",
"ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)"
]
},
{
"cell_type": "markdown",
"id": "1acbcc09",
"metadata": {},
"source": [
"## Using `with_attributes`\n",
"\n",
"If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged."
"If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares JAX arrays with the original, leaving the original unchanged."
]
},
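Conceptually, attribute replacement behaves like `dataclasses.replace`: a shallow copy with some fields swapped and the remaining fields still aliased. A toy sketch (hypothetical `Layer` class, not the NNX API):

```python
from dataclasses import dataclass, replace

# Hypothetical sketch, not the NNX API: attribute replacement is a
# shallow copy with one field swapped, leaving the rest aliased.
@dataclass(frozen=True)
class Layer:
    kernel: list   # stands in for a JAX array
    rate: float    # a plain static attribute

layer = Layer(kernel=[1.0, 2.0], rate=0.1)
layer_view = replace(layer, rate=0.5)  # new object, one attribute swapped

assert layer_view.rate == 0.5
assert layer.rate == 0.1                   # original left unchanged
assert layer_view.kernel is layer.kernel   # "array" storage still shared
```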
{
@@ -412,7 +476,33 @@
"id": "bf521e45",
"metadata": {},
"source": [
"Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`."
"Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n",
"\n",
"## Other NNX views\n",
"\n",
"Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n",
"\n",
"- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n",
"\n",
" ```python\n",
" _, state = nnx.split(model)\n",
" pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain\n",
" ```\n",
"\n",
"- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n",
"\n",
" ```python\n",
" with jax.set_mesh(mesh):\n",
" abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n",
" abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars\n",
" ```\n",
"\n",
"- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n",
"\n",
" ```python\n",
" # Split params stream into 4 keys (one per vmap replica); fork the rest\n",
" vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n",
" ```"
]
}
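The pattern behind `recursive_map` can be sketched in plain Python: walk a nested structure, let a function decide whether to rewrite each node, and rebuild containers around unchanged children. This is a hypothetical sketch over dicts and lists, with an assumed "return `None` to recurse" convention, not the NNX implementation:

```python
def recursive_map(fn, node):
    # Give fn a chance to replace this node; None means "recurse into it".
    replaced = fn(node)
    if replaced is not None:
        return replaced
    if isinstance(node, dict):
        return {k: recursive_map(fn, v) for k, v in node.items()}
    if isinstance(node, list):
        return [recursive_map(fn, v) for v in node]
    return node  # leaf, returned as-is (still shared with the original)

tree = {"lin1": {"kind": "Linear"}, "act": "relu"}

def swap(node):
    # Wrap "Linear" nodes, analogous to swapping Linear for NoisyLinear.
    if isinstance(node, dict) and node.get("kind") == "Linear":
        return {"kind": "Noisy", "inner": node}
    return None

mapped = recursive_map(swap, tree)
assert mapped["lin1"]["inner"] is tree["lin1"]  # original node reused inside wrapper
assert tree["lin1"]["kind"] == "Linear"         # original tree unchanged
```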
],