
Commit fbf5dbe

Make all NNX view functions use recursive_map or nnx.map (to handle graphs)
1 parent 96e05c5 commit fbf5dbe

File tree

6 files changed: +380 −29 lines


docs_nnx/guides/view.ipynb

Lines changed: 100 additions & 10 deletions
@@ -6,7 +6,9 @@
    "metadata": {},
    "source": [
     "# Model Views\n",
-    "This guide covers how to use NNX \"Views\", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes."
+    "This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes.\n",
+    "\n",
+    "NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original."
    ]
   },
   {
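The view idea described in this hunk (static attributes change while the underlying data stays shared) can be sketched in plain Python. This is a toy illustration, not the flax API; `DropoutView` and its fields are hypothetical names.

```python
from dataclasses import dataclass, replace

# Toy sketch (not the flax API): a "view" carries its own static
# attributes but shares the underlying data object with the original.
@dataclass(frozen=True)
class DropoutView:
    rate: float
    deterministic: bool
    weights: list  # shared, mutable underlying data

base = DropoutView(rate=0.5, deterministic=False, weights=[1.0, 2.0])
eval_view = replace(base, deterministic=True)  # new static attributes...

assert eval_view.weights is base.weights       # ...same shared data
base.weights[0] = 9.0                          # an update to the data...
assert eval_view.weights[0] == 9.0             # ...is visible in every view
```

Because only the static attributes differ, any number of views can coexist without duplicating the data they wrap.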
@@ -48,7 +50,7 @@
    "\n",
    "Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten.\n",
    "\n",
-    "`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below."
+    "`nnx.with_modules` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below."
    ]
   },
   {
@@ -115,7 +117,7 @@
    "source": [
     "From the model display, we can see that `Dropout` has `deterministic == False`, suggesting that the model is in training mode. In order to know this, we had to display the model and/or know that `Dropout` is set to training mode by default. It is not clear what state the model is in just by looking at the code without additional inspection. We instead want to be very explicit about what state the model is in.\n",
     "\n",
-    "This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below."
+    "This is where `nnx.with_modules` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below."
    ]
   },
   {
@@ -141,7 +143,7 @@
    "id": "5c1ee1db",
    "metadata": {},
    "source": [
-    "## Example with `nnx.view`"
+    "## Example with `nnx.with_modules`"
    ]
   },
   {
@@ -216,7 +218,7 @@
    "metadata": {},
    "source": [
     "## Getting information with `nnx.view_info`\n",
-    "To see more information about the options for `nnx.view`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions."
+    "To see more information about the options for `nnx.with_modules`, we can use the `nnx.view_info` function to display information about the arguments. This will display each submodule which contains a `set_view` method. It also provides information about the keyword arguments accepted by each submodule, including type information, default values, and docstring descriptions."
    ]
   },
   {
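As a rough illustration of the kind of information a `view_info`-style helper could collect, `inspect.signature` is enough to recover parameter names, annotations, and defaults from a `set_view` method. This sketch is not flax's implementation; `describe_set_view` and the toy `Dropout` class are hypothetical names.

```python
import inspect

# Toy module with a set_view method, as described in the guide above.
class Dropout:
    def set_view(self, deterministic: bool = False):
        self.deterministic = deterministic

def describe_set_view(module_cls):
    """Report each set_view parameter's annotation and default (sketch only)."""
    sig = inspect.signature(module_cls.set_view)
    return {
        name: (p.annotation, p.default)
        for name, p in sig.parameters.items()
        if name != 'self'
    }

info = describe_set_view(Dropout)
assert info == {'deterministic': (bool, False)}
```

Docstring descriptions could be pulled the same way from `module_cls.set_view.__doc__`; the sketch keeps to type and default information.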
@@ -234,9 +236,9 @@
    "id": "47479be6",
    "metadata": {},
    "source": [
-    "## Writing modules compatible with `nnx.view`\n",
+    "## Writing modules compatible with `nnx.with_modules`\n",
     "\n",
-    "You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.view` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about.\n",
+    "You can make any custom module work with `nnx.with_modules` by defining a `set_view` method. When `nnx.with_modules` is called, it traverses the module tree and calls `set_view` on every submodule that defines one. `nnx.with_modules` inspects the signature of each `set_view` method and only passes the keyword arguments that match the method's declared parameters. This means each module only receives the kwargs it cares about.\n",
     "\n",
     "Your `set_view` method should follow these conventions:\n",
     "\n",
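The kwarg-matching behavior described in this hunk can be sketched with `inspect.signature`: forward to each `set_view` only the keyword arguments its signature declares. A simplified sketch, not the flax implementation; `apply_view_kwargs` and the toy modules are hypothetical names.

```python
import inspect

# Two toy modules, each declaring only the view kwargs it cares about.
class Dropout:
    def set_view(self, deterministic: bool = False):
        self.deterministic = deterministic

class BatchNorm:
    def set_view(self, use_running_average: bool = False):
        self.use_running_average = use_running_average

def apply_view_kwargs(modules, **kwargs):
    for m in modules:
        accepted = inspect.signature(m.set_view).parameters
        # Forward only the kwargs that appear in this module's signature.
        m.set_view(**{k: v for k, v in kwargs.items() if k in accepted})

drop, bn = Dropout(), BatchNorm()
apply_view_kwargs([drop, bn], deterministic=True, use_running_average=True)
assert drop.deterministic is True
assert bn.use_running_average is True
```

Filtering against the declared parameters is what lets one call site pass a union of kwargs without every module having to accept `**kwargs`.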
@@ -320,14 +322,76 @@
   },
   {
    "cell_type": "markdown",
-   "id": "1acbcc09",
+   "id": "984b8eca",
    "metadata": {},
    "source": [
     "The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n",
     "\n",
+    "## Using `with_vars`\n",
+    "\n",
+    "{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n",
+    "\n",
+    "The flags it controls are:\n",
+    "\n",
+    "- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n",
+    "- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n",
+    "- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n",
+    "\n",
+    "The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0938e1f6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from flax import nnx\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "class SimpleModel(nnx.Module):\n",
+    "  def __init__(self, rngs):\n",
+    "    self.linear = nnx.Linear(2, 3, rngs=rngs)\n",
+    "\n",
+    "model = SimpleModel(nnx.Rngs(0))\n",
+    "\n",
+    "# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n",
+    "ref_model = nnx.with_vars(model, ref=True)\n",
+    "ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n",
+    "\n",
+    "# The original model's kernel is unchanged; ref_model has doubled values\n",
+    "assert model.linear.kernel is not ref_model.linear.kernel"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "67c71d62",
+   "metadata": {},
+   "source": [
+    "Use the `only` filter to convert only a subset of Variables:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "254fe344",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# only convert Param variables, leave BatchStat variables unchanged\n",
+    "ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1acbcc09",
+   "metadata": {},
+   "source": [
     "## Using `with_attributes`\n",
     "\n",
-    "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.view`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged."
+    "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged."
    ]
   },
   {
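The `only`-filter pattern in this hunk (copy the Variables matching a filter, share the rest) can be sketched in plain Python. A toy illustration, not the flax API; `with_copies` and the stand-in `Param`/`BatchStat` classes are hypothetical names.

```python
# Toy variable wrappers standing in for nnx.Param / nnx.BatchStat.
class Param:
    def __init__(self, value):
        self.value = value

class BatchStat:
    def __init__(self, value):
        self.value = value

def with_copies(variables, only):
    """Copy variables matching the `only` filter; share the rest (sketch)."""
    return {
        name: type(v)(v.value) if isinstance(v, only) else v
        for name, v in variables.items()
    }

tree = {'kernel': Param(1.0), 'mean': BatchStat(0.0)}
view = with_copies(tree, only=Param)

assert view['kernel'] is not tree['kernel']  # Param was copied into the view
assert view['mean'] is tree['mean']          # BatchStat is shared unchanged
```

The real API applies this selection recursively over a module tree and swaps JAX-level flags on the copies; the sketch only shows the filter-and-share shape of the operation.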
@@ -412,7 +476,33 @@
    "id": "bf521e45",
    "metadata": {},
    "source": [
-    "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`."
+    "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n",
+    "\n",
+    "## Other NNX views\n",
+    "\n",
+    "Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n",
+    "\n",
+    "- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n",
+    "\n",
+    "  ```python\n",
+    "  _, state = nnx.split(model)\n",
+    "  pure_state = nnx.as_pure(state)  # Variable wrappers removed; plain arrays remain\n",
+    "  ```\n",
+    "\n",
+    "- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n",
+    "\n",
+    "  ```python\n",
+    "  with jax.set_mesh(mesh):\n",
+    "    abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n",
+    "    abs_model = nnx.as_abstract(abs_model)  # sharding attached to abstract vars\n",
+    "  ```\n",
+    "\n",
+    "- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n",
+    "\n",
+    "  ```python\n",
+    "  # Split params stream into 4 keys (one per vmap replica); fork the rest\n",
+    "  vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n",
+    "  ```"
    ]
   }
  ],
