Commit e2f9442

committed
Add to Views guide
1 parent 103f4cc commit e2f9442

2 files changed

+159
-3
lines changed

docs_nnx/guides/view.ipynb

Lines changed: 93 additions & 3 deletions
@@ -6,7 +6,9 @@
66
"metadata": {},
77
"source": [
88
"# Model Views\n",
9-
"This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for NumPy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes."
9+
"This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for NumPy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes.\n",
10+
"\n",
11+
"NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original."
1012
]
1113
},
1214
{
@@ -320,11 +322,73 @@
320322
},
321323
{
322324
"cell_type": "markdown",
323-
"id": "1acbcc09",
325+
"id": "984b8eca",
324326
"metadata": {},
325327
"source": [
326328
"The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n",
327329
"\n",
330+
"## Using `with_vars`\n",
331+
"\n",
332+
"{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n",
333+
"\n",
334+
"The flags it controls are:\n",
335+
"\n",
336+
"- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n",
337+
"- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n",
338+
"- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n",
339+
"\n",
340+
"The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)."
341+
]
342+
},
343+
{
344+
"cell_type": "code",
345+
"execution_count": null,
346+
"id": "0938e1f6",
347+
"metadata": {},
348+
"outputs": [],
349+
"source": [
350+
"from flax import nnx\n",
351+
"import jax\n",
352+
"import jax.numpy as jnp\n",
353+
"\n",
354+
"class SimpleModel(nnx.Module):\n",
355+
"  def __init__(self, rngs):\n",
356+
"    self.linear = nnx.Linear(2, 3, rngs=rngs)\n",
357+
"\n",
358+
"model = SimpleModel(nnx.Rngs(0))\n",
359+
"\n",
360+
"# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n",
361+
"ref_model = nnx.with_vars(model, ref=True)\n",
362+
"ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n",
363+
"\n",
364+
"# The original model's kernel is unchanged; ref_model has doubled values\n",
365+
"assert model.linear.kernel is not ref_model.linear.kernel"
366+
]
367+
},
368+
{
369+
"cell_type": "markdown",
370+
"id": "67c71d62",
371+
"metadata": {},
372+
"source": [
373+
"Use the `only` filter to convert only a subset of Variables:"
374+
]
375+
},
376+
{
377+
"cell_type": "code",
378+
"execution_count": null,
379+
"id": "254fe344",
380+
"metadata": {},
381+
"outputs": [],
382+
"source": [
383+
"# Convert only Param variables; other Variable types (e.g. BatchStat) are left unchanged\n",
384+
"ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)"
385+
]
386+
},
387+
{
388+
"cell_type": "markdown",
389+
"id": "1acbcc09",
390+
"metadata": {},
391+
"source": [
328392
"## Using `with_attributes`\n",
329393
"\n",
330394
"If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares JAX arrays with the original, leaving the original unchanged."
@@ -412,7 +476,33 @@
412476
"id": "bf521e45",
413477
"metadata": {},
414478
"source": [
415-
"Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`."
479+
"Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n",
480+
"\n",
481+
"## Other NNX views\n",
482+
"\n",
483+
"Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n",
484+
"\n",
485+
"- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n",
486+
"\n",
487+
" ```python\n",
488+
" _, state = nnx.split(model)\n",
489+
" pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain\n",
490+
" ```\n",
491+
"\n",
492+
"- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n",
493+
"\n",
494+
" ```python\n",
495+
" with jax.set_mesh(mesh):\n",
496+
"   abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n",
497+
"   abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars\n",
498+
" ```\n",
499+
"\n",
500+
"- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n",
501+
"\n",
502+
" ```python\n",
503+
" # Split params stream into 4 keys (one per vmap replica); fork the rest\n",
504+
" vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n",
505+
" ```"
416506
]
417507
}
418508
],

docs_nnx/guides/view.md

Lines changed: 66 additions & 0 deletions
@@ -11,6 +11,8 @@ jupytext:
1111
# Model Views
1212
This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for NumPy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that sets module modes.
1313

14+
NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original.
15+
1416
```{code-cell}
1517
from flax import nnx
1618
@@ -216,6 +218,44 @@ print(nnx.view_info(model))
216218

217219
The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.
218220

221+
## Using `with_vars`
222+
223+
{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.
224+
225+
The flags it controls are:
226+
227+
- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.
228+
- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.
229+
- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.
230+
231+
The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original).
232+
233+
```{code-cell}
234+
from flax import nnx
235+
import jax
236+
import jax.numpy as jnp
237+
238+
class SimpleModel(nnx.Module):
239+
  def __init__(self, rngs):
240+
    self.linear = nnx.Linear(2, 3, rngs=rngs)
241+
242+
model = SimpleModel(nnx.Rngs(0))
243+
244+
# ref=True: expose Variable values as JAX refs so jax.tree.map can update them
245+
ref_model = nnx.with_vars(model, ref=True)
246+
ref_model = jax.tree.map(lambda x: x * 2, ref_model)
247+
248+
# The original model's kernel is unchanged; ref_model has doubled values
249+
assert model.linear.kernel is not ref_model.linear.kernel
250+
```
251+
252+
Use the `only` filter to convert only a subset of Variables:
253+
254+
```{code-cell}
255+
# Convert only Param variables; other Variable types (e.g. BatchStat) are left unchanged
256+
ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)
257+
```
258+
219259
## Using `with_attributes`
220260

221261
If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares JAX arrays with the original, leaving the original unchanged.
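
The sharing behaviour described for `nnx.with_attributes` can be illustrated without Flax. The following is a minimal plain-Python sketch, not the NNX implementation: the toy `Model`, `Linear`, and `Dropout` classes and the `replace_attributes` helper are hypothetical names introduced only for this illustration, and a Python list stands in for a JAX array.

```python
import copy


class Dropout:
    """Toy layer with one static flag and no array data."""

    def __init__(self, deterministic=False):
        self.deterministic = deterministic


class Linear:
    """Toy layer whose `kernel` list stands in for a JAX array."""

    def __init__(self, kernel):
        self.kernel = kernel


class Model:
    def __init__(self):
        self.linear = Linear([1.0, 2.0, 3.0])
        self.dropout = Dropout(deterministic=False)


def replace_attributes(node, **attrs):
    """Hypothetical helper: shallow-copy each layer, overriding matching attributes.

    Values that are not layers (the stand-in arrays) are returned as-is,
    so the copy and the original keep pointing at the same data.
    """
    if not isinstance(node, (Model, Linear, Dropout)):
        return node  # data leaf: shared, never copied
    new_node = copy.copy(node)
    for name, value in vars(node).items():
        if name in attrs:
            setattr(new_node, name, attrs[name])
        else:
            setattr(new_node, name, replace_attributes(value, **attrs))
    return new_node


model = Model()
eval_model = replace_attributes(model, deterministic=True)
```

In this sketch `eval_model.dropout.deterministic` is `True` while `model.dropout.deterministic` stays `False`, and `eval_model.linear.kernel` is the same object as `model.linear.kernel`, which mirrors the "new instance, shared arrays" behaviour the paragraph describes.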
@@ -280,3 +320,29 @@ print(noisy_model)
280320
```
281321

282322
Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.
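
The traversal pattern described here can be sketched in plain Python. This is an illustrative assumption about the general technique, not the actual `recursive_map` code or signature: `map_nodes`, `add_noise`, and the toy classes below are hypothetical, and lists stand in for weight arrays.

```python
import copy


class Linear:
    def __init__(self, kernel):
        self.kernel = kernel


class NoisyLinear:
    """Wrapper that keeps the wrapped layer as `inner`."""

    def __init__(self, inner):
        self.inner = inner


class MLP:
    def __init__(self):
        self.hidden = Linear([0.1, 0.2])
        self.out = Linear([0.3])


def map_nodes(fn, node):
    """Hypothetical helper: rebuild children first, then let `fn` replace the node."""
    if isinstance(node, (MLP, Linear, NoisyLinear)):
        new_node = copy.copy(node)  # shallow copy, so leaf data stays shared
        for name, child in vars(node).items():
            setattr(new_node, name, map_nodes(fn, child))
        node = new_node
    return fn(node)


def add_noise(node):
    # Wrap every Linear in a NoisyLinear; return everything else unchanged.
    return NoisyLinear(node) if isinstance(node, Linear) else node


model = MLP()
noisy_model = map_nodes(add_noise, model)
```

Here the wrapped layer in the sketch is a shallow copy of the original `Linear` rather than the original object itself, but its `kernel` is the same list, so the weights are shared and `model` still holds plain `Linear` layers.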
323+
324+
## Other NNX views
325+
326+
Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:
327+
328+
- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.
329+
330+
```python
331+
_, state = nnx.split(model)
332+
pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain
333+
```
334+
335+
- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.
336+
337+
```python
338+
with jax.set_mesh(mesh):
339+
  abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))
340+
  abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars
341+
```
342+
343+
- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.
344+
345+
```python
346+
# Split params stream into 4 keys (one per vmap replica); fork the rest
347+
vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)
348+
```
