Commit 357b298

Author: Flax Authors (committed)
Merge pull request #5294 from google:nnx-view

PiperOrigin-RevId: 878004673
2 parents: e067f39 + d5ce436

File tree

5 files changed: +263 −318 lines

docs_nnx/guides/view.ipynb

Lines changed: 126 additions & 167 deletions (large diff not rendered by default)

docs_nnx/guides/view.md

Lines changed: 96 additions & 103 deletions
@@ -8,15 +8,17 @@ jupytext:
     jupytext_version: 1.13.8
 ---
 
-# Using `nnx.view`
-This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this method works, refer to the following example:
+# Model Views
+This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:
 
 ```{code-cell} ipython3
 from flax import nnx
 
 # example model with different train/eval behavior
 rngs = nnx.Rngs(0)
-model = nnx.Sequential(nnx.Linear(2, 4, rngs=rngs), nnx.BatchNorm(4, rngs=rngs), nnx.Dropout(0.1))
+model = nnx.Sequential(
+  nnx.Linear(2, 4, rngs=rngs), nnx.BatchNorm(4, rngs=rngs), nnx.Dropout(0.1)
+)
 
 # set train and eval modes
 train_model = nnx.view(model, deterministic=False, use_running_average=False)
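To build intuition for what a view is, here is a pure-Python sketch of the core idea (`SimpleView` and `TinyModel` are hypothetical names, not the Flax implementation): a thin wrapper that overrides a few configuration attributes while delegating everything else, including the weights, to the shared base object.

```python
class SimpleView:
  """Hypothetical sketch of a view: overrides some attributes, shares the rest."""

  def __init__(self, base, **overrides):
    object.__setattr__(self, '_base', base)
    object.__setattr__(self, '_overrides', overrides)

  def __getattr__(self, name):
    # called only when normal lookup fails, i.e. for everything
    # except _base and _overrides themselves
    overrides = object.__getattribute__(self, '_overrides')
    if name in overrides:
      return overrides[name]
    return getattr(object.__getattribute__(self, '_base'), name)


class TinyModel:
  def __init__(self):
    self.kernel = [1.0, 2.0]     # stands in for a weight array
    self.deterministic = True


base = TinyModel()
train_view = SimpleView(base, deterministic=False)
eval_view = SimpleView(base, deterministic=True)

assert train_view.kernel is eval_view.kernel   # weights are shared
assert train_view.deterministic is False       # modes differ per view
assert eval_view.deterministic is True
```

Unlike this sketch, `nnx.view` returns real module objects, but the property the guide relies on is the same: there is only one copy of the parameters, however many views exist.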
@@ -35,9 +37,9 @@ print(nnx.view_info(model))
 
 ## Motivation
 
-Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern for doing this is to have a single `model` and use a method like `.train()` before running training code. This requires that the programmer remembers to call these functions in many places in the code and can make the code less readable. Moreover, it can be cumbersome to trace through someone else's code trying to infer the state of a model.
+Some layers in ML inherently involve state. Consider for example the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten.
 
-An alternative to this pattern is to declare the different model states at the beginning of the code. The `nnx.view` method lets you apply multiple model configurations at the beginning of your code and then just use these later without having to call functions like `.train()` or `.eval()`. We demonstrate this with a simple example below.
+`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below.
 
 ```{code-cell} ipython3
 import jax
@@ -46,19 +48,25 @@ import matplotlib.pyplot as plt
 
 in_dim, hidden_dim, out_dim = 16, 32, 2
 
+
 class MyModel(nnx.Module):
-  def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, dropout_rate: float, *, rngs: nnx.Rngs):
-    self.lin1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
-    self.do = nnx.Dropout(dropout_rate)
-    self.bn = nnx.BatchNorm(hidden_dim, rngs=rngs)
-    self.lin2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)
-
-  def __call__(self, x, *, rngs):
-    x = self.lin1(x)
-    x = self.bn(x)
-    x = nnx.relu(self.do(x, rngs=rngs))
-    x = self.lin2(x)
-    return x
+  def __init__(
+    self,
+    in_dim: int,
+    hidden_dim: int,
+    out_dim: int,
+    dropout_rate: float,
+    *,
+    rngs: nnx.Rngs,
+  ):
+    self.lin1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
+    self.do = nnx.Dropout(dropout_rate)
+    self.bn = nnx.BatchNorm(hidden_dim, rngs=rngs)
+    self.lin2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)
+
+  def __call__(self, x, *, rngs=None):
+    x = nnx.relu(self.do(self.bn(self.lin1(x)), rngs=rngs))
+    return self.lin2(x)
 ```
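To make the train/eval distinction concrete, here is an illustrative stand-in for a dropout-style layer in plain Python (a hypothetical `ToyDropout`, not Flax's `nnx.Dropout`): when `deterministic` is `True` it is the identity; otherwise it zeroes entries at the given rate and rescales the survivors.

```python
import random


class ToyDropout:
  """Illustrative dropout-like layer (not Flax's implementation)."""

  def __init__(self, rate: float, deterministic: bool = True):
    self.rate = rate
    self.deterministic = deterministic

  def __call__(self, xs, rng=None):
    if self.deterministic:
      return list(xs)  # eval mode: identity
    keep = 1.0 - self.rate
    # train mode: drop each entry with probability `rate`,
    # rescale survivors so the expected value is unchanged
    return [x / keep if rng.random() < keep else 0.0 for x in xs]


do = ToyDropout(0.5)
assert do([1.0, 2.0]) == [1.0, 2.0]  # deterministic: input passes through

do.deterministic = False
out = do([1.0, 2.0], rng=random.Random(0))
assert all(v in (0.0, 2.0, 4.0) for v in out)  # each entry dropped or doubled
```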
 
 Let's take a look at the model to see what is going on.
@@ -81,73 +89,55 @@ train_model = nnx.view(model, deterministic=False)
 eval_model = nnx.view(model, deterministic=True)
 
 # weights are references to the same data
-assert train_model.lin1.kernel[...] is eval_model.lin1.kernel[...]
+assert train_model.lin1.kernel is eval_model.lin1.kernel
 
 # Dropout.deterministic is different in each model
 assert train_model.do.deterministic is False
 assert eval_model.do.deterministic is True
 ```
 
-Lets see what implications this has for our code design. We set up some simple training and evaluation functions below.
+## Example with `nnx.view`
+
++++
+
+We first set up data generators and define train/eval step functions. The `train_step` receives an `nnx.Rngs` object for dropout randomness, while `eval_step` doesn't since dropout is disabled in `eval_model`.
 
 ```{code-cell} ipython3
 ndata, batch_size, total_epochs, lr = 2048, 32, 100, 1e-3
-x = jax.random.normal(jax.random.key(0), (ndata, in_dim))
-y = jax.random.normal(jax.random.key(0), (ndata, out_dim))
+rngs = nnx.Rngs(0)
+x = rngs.normal((ndata, in_dim))
+y = rngs.normal((ndata, out_dim))
+
 
 @nnx.jit
 def train_step(model, optimizer, x, y, rngs):
-  def loss_fn(model):
-    y_pred = model(x, rngs=rngs)
-    return ((y_pred - y) ** 2).mean()
+  def loss_fn(model, rngs):
+    return ((model(x, rngs=rngs) - y) ** 2).mean()
 
-  loss, grads = nnx.value_and_grad(loss_fn)(model)
+  grads = nnx.grad(loss_fn)(model, rngs)
   optimizer.update(model, grads)
-  return loss
+
 
 @nnx.jit
 def eval_step(model, x, y):
-  return ((model(x, rngs=None) - y) ** 2).mean()
+  return ((model(x) - y) ** 2).mean()
 ```
 
-## Example with `nnx.view`
+Now we create `train_model` and `eval_model` views up front. During the training loop we simply use the appropriate view; there is no need to call `.train()` or `.eval()`, and it is always clear from the code which mode the model is in.
 
 ```{code-cell} ipython3
-model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=nnx.Rngs(0))
-train_model = nnx.view(model, deterministic=False)  # create training view
-eval_model = nnx.view(model, deterministic=True)  # create evaluation view
+model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)
 optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)
-key = jax.random.key(0)
+train_model = nnx.view(model, deterministic=False)  # training view
+eval_model = nnx.view(model, deterministic=True)  # eval view
 
-eval_results = [None] * total_epochs
+eval_results = []
 for epoch in range(total_epochs):
-  for i in range(ndata // batch_size):
-    sl = slice(i*batch_size,(i+1)*batch_size)
-    key, subkey = jax.random.split(key)
-    train_step(train_model, optimizer, x[sl], y[sl], subkey)  # use train_model
+  for i in range(ndata // batch_size):
+    idx = slice(i * batch_size, (i + 1) * batch_size)
+    train_step(train_model, optimizer, x[idx], y[idx], rngs)  # use train_model
 
-  eval_results[epoch] = eval_step(eval_model, x, y)  # use eval_model
-plt.plot(eval_results)
-plt.show()
-```
-
-## Example with Old API
-
-```{code-cell} ipython3
-model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=nnx.Rngs(0))
-optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)
-key = jax.random.key(0)
-
-eval_results = [None] * total_epochs
-for epoch in range(total_epochs):
-  model.set_attributes(deterministic=False)  # set to train mode
-  for i in range(ndata // batch_size):
-    sl = slice(i*batch_size,(i+1)*batch_size)
-    key, subkey = jax.random.split(key)
-    train_step(model, optimizer, x[sl], y[sl], subkey)
-
-  model.set_attributes(deterministic=True)  # set to eval mode
-  eval_results[epoch] = eval_step(model, x, y)
+  eval_results.append(eval_step(eval_model, x, y))  # use eval_model
 plt.plot(eval_results)
 plt.show()
 ```
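The inner loop's batching above can be isolated as a small helper (an illustrative sketch; the guide inlines the slice arithmetic): each step consumes one contiguous slice of the data, and any trailing partial batch is dropped by the integer division.

```python
def minibatch_slices(ndata: int, batch_size: int):
  """Yield contiguous index slices covering the full batches, as in the loop above."""
  for i in range(ndata // batch_size):
    yield slice(i * batch_size, (i + 1) * batch_size)


data = list(range(10))
batches = [data[s] for s in minibatch_slices(len(data), 4)]
assert batches == [[0, 1, 2, 3], [4, 5, 6, 7]]  # 10 // 4 = 2 full batches
```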
@@ -160,68 +150,71 @@ print(nnx.view_info(model))
 ```
 
 ## Writing modules compatible with `nnx.view`
-To implement a module that is compatible with `nnx.view` you just need to define a `set_view` class method. This method should
-1. Include input kwargs, type annotations, and a default value of `None`.
-2. Only update the mode if the kwarg is not `None`.
-3. Include input `**kwargs` for additional keyword arguments used in `nnx.view`. These should not be used by `set_view`.
-4. Include a Google-style docstring (for parsing with `nnx.view_info`).
-5. Return kwargs for identifying unused kwargs.
-
-This will look like
+
+You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one, forwarding the keyword arguments you passed.
+
+Your `set_view` method should follow these conventions:
+
+1. **Accept keyword arguments with `None` defaults.** Each kwarg represents a configurable mode for this module. A `None` default means "leave unchanged", so views only override the modes you explicitly set.
+2. **Only update the attribute when the kwarg is not `None`.** This ensures that unrelated views don't accidentally reset each other's settings.
+3. **Accept `**kwargs` and return it.** This lets other submodules in the tree consume their own keyword arguments without raising errors about unexpected kwargs.
+4. **Include a Google-style docstring.** The `nnx.view_info` function parses these docstrings to display human-readable information about available view options.
+
+The general pattern looks like this:
+
 ```python
 class MyLayer(nnx.Module):
   ...
 
-  def set_view(self, kwarg1: type1 = default1, ..., kwargN: typeN = defaultN, **kwargs) -> dict:
-    """Module docstring following Google-style docstrings"""
-    # logic to update the mode
+  def set_view(self, kwarg1: type1 = None, ..., kwargN: typeN = None, **kwargs) -> dict:
+    """Description of the module's configurable modes.
+
+    Args:
+      kwarg1: description of kwarg1.
+      ...
+      kwargN: description of kwargN.
+    """
+    if kwarg1 is not None:
+      self.kwarg1 = kwarg1
+    ...
     return kwargs
 ```
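To see how these conventions compose, here is a pure-Python sketch of a view-style traversal (`apply_view`, `ToyModule`, and `ToyLayer` are hypothetical names, not the Flax implementation, and for brevity this sketch mutates in place rather than returning a data-sharing view): each `set_view` consumes the kwargs it understands and forwards the rest.

```python
class ToyModule:
  def set_view(self, **kwargs):
    return kwargs  # default: consume nothing, forward everything


class ToyLayer(ToyModule):
  def __init__(self):
    self.deterministic = True

  def set_view(self, deterministic=None, **kwargs):
    """Toy layer with one configurable mode.

    Args:
      deterministic: if set, toggles random behavior off or on.
    """
    if deterministic is not None:  # only update when explicitly requested
      self.deterministic = deterministic
    return kwargs


def apply_view(root, **kwargs):
  # call set_view on the root, then on each child module
  # (one level deep for brevity; a real traversal would recurse)
  kwargs = root.set_view(**kwargs)
  for child in vars(root).values():
    if isinstance(child, ToyModule):
      kwargs = child.set_view(**kwargs)
  return kwargs


class ToyNet(ToyModule):
  def __init__(self):
    self.layer = ToyLayer()


net = ToyNet()
leftover = apply_view(net, deterministic=False, unknown_flag=1)
assert net.layer.deterministic is False
assert leftover == {'unknown_flag': 1}  # unused kwargs are returned
```

Note how returning `kwargs` lets the caller detect arguments that no submodule consumed, which matches convention 3 above.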

-Consider the following example
+Here is a concrete example: a `PrintLayer` that can be toggled to print a message during its forward pass.
 
 ```{code-cell} ipython3
 class PrintLayer(nnx.Module):
-  def __init__(self, msg: str, *, rngs: nnx.Rngs):
-    self.print_msg = None
-    self.msg = msg
+  def __init__(self, msg: str | None = None):
+    self.msg = msg
 
-  def __call__(self, *args, **kwargs):
-    if self.print_msg:
-      print(self.msg)
+  def __call__(self, *args, **kwargs):
+    if self.msg:
+      print(self.msg)
 
-  def set_view(self, print_msg: bool | None = None, **kwargs) -> dict:
-    """Example set_view docstring. This follows Google style docstrings.
-
-    Args:
-      print_msg: bool indicating if a message should be printed.
-        If True, the `__call__` method prints the message.
-    """
-    if print_msg is not None:
-      self.print_msg = print_msg
-    return kwargs
-
-l = PrintLayer("Hello, World!", rngs=nnx.Rngs(0))
-l_print = nnx.view(l, print_msg=True)
-```
+  def set_view(self, msg: str | None = None, **kwargs) -> dict:
+    """Example set_view docstring. This follows Google style docstrings.
 
-```{code-cell} ipython3
-# print the default view
-print(l)
+    Args:
+      msg: the message to print.
+        If set, the `__call__` method prints the message.
+    """
+    if msg is not None:
+      self.msg = msg
+    return kwargs
 
-# Nothing printed from call method
-l()
-```
 
-```{code-cell} ipython3
-# print the l_print view
-print(l_print)
+model = PrintLayer()
+model_print = nnx.view(model, msg='Hello, World!')
 
-# Prints "Hello, World!" from the call method
-l_print()
+model()  # nothing printed
+model_print()  # prints "Hello, World!"
 ```
 
+We can use `nnx.view_info` to inspect what view options `PrintLayer` exposes. This is especially handy when working with unfamiliar models: it lists every submodule that defines `set_view`, along with the accepted kwargs, their types, defaults, and docstring descriptions.
+
 ```{code-cell} ipython3
 # Display the information for nnx.view
-print(nnx.view_info(l))
+print(nnx.view_info(model))
 ```
+
+The output shows that `PrintLayer` accepts a `msg` kwarg of type `str` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.
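As a rough idea of how docstring-based introspection like this can work (a hypothetical sketch, not the actual `nnx.view_info` implementation; `set_view_args` is an invented helper, and the `PrintLayer` here is a stripped-down stand-in without nnx), one can pull the `Args:` section out of a Google-style docstring with `inspect`:

```python
import inspect


def set_view_args(cls):
  """Parse the Args: section of a Google-style set_view docstring (sketch)."""
  doc = inspect.getdoc(cls.set_view) or ''  # getdoc dedents the docstring
  args, in_args = {}, False
  for line in doc.splitlines():
    if line.strip() == 'Args:':
      in_args = True
      continue
    if not in_args or not line.strip():
      continue
    indent = len(line) - len(line.lstrip())
    if indent <= 2 and ':' in line:
      # a new "name: description" entry
      name, desc = line.strip().split(':', 1)
      args[name] = desc.strip()
    elif args:
      # deeper-indented continuation line of the previous entry
      last = next(reversed(args))
      args[last] += ' ' + line.strip()
  return args


class PrintLayer:
  def set_view(self, msg=None, **kwargs):
    """Example set_view docstring.

    Args:
      msg: the message to print.
        If set, `__call__` prints it.
    """
    if msg is not None:
      self.msg = msg
    return kwargs


info = set_view_args(PrintLayer)
assert info == {'msg': 'the message to print. If set, `__call__` prints it.'}
```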

docs_nnx/guides_basic.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ Basic Guides
 
    guides/pytree
    guides/transforms
+   guides/view
    guides/filters_guide
    guides/randomness
    guides/checkpointing
0 commit comments