# Model Views
This guide covers how to use the `nnx.view` function, which is useful for handling state in layers like `Dropout` and `BatchNorm` that behave differently in training and evaluation. Similar to `.view` for NumPy arrays, `nnx.view` allows you to set modes of the model while still sharing the same underlying data. For a quick intro to how this function works, refer to the following example:
```{code-cell} ipython3
from flax import nnx
# example model with different train/eval behavior
rngs = nnx.Rngs(0)
class MyModel(nnx.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout_rate, *, rngs):
        self.lin1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
        self.do = nnx.Dropout(dropout_rate)
        self.lin2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)

    def __call__(self, x, rngs=None):
        return self.lin2(self.do(nnx.relu(self.lin1(x)), rngs=rngs))

in_dim, hidden_dim, out_dim = 2, 32, 1
model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)

# create two views of the same model with different modes
train_model = nnx.view(model, deterministic=False)  # training view
eval_model = nnx.view(model, deterministic=True)  # evaluation view
```
Some layers in ML inherently involve state. Consider, for example, the `nnx.Dropout` layer, which behaves differently during training and evaluation. In these different scenarios, we need a simple way to ensure that the model behaves as intended to avoid silent bugs. A common pattern in other frameworks is to mutate a single `model` object to switch between training and evaluation modes. This requires the programmer to remember to toggle modes in many places throughout the code, which can hurt readability and lead to subtle bugs when a mode switch is forgotten.
`nnx.view` offers a cleaner alternative: you declare the different model configurations once at the beginning of your code and then simply use the appropriate view wherever needed. Each view shares the same underlying weights, so parameter updates are automatically reflected across all views. We demonstrate this with a simple example below.
```{code-cell} ipython3
import jax
import matplotlib.pyplot as plt
```

```{code-cell} ipython3
# the views share the same parameter arrays
assert train_model.lin1.kernel is eval_model.lin1.kernel
# Dropout.deterministic is different in each model
assert train_model.do.deterministic is False
assert eval_model.do.deterministic is True
```
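The identity checks above mirror how `.view` works for NumPy arrays, as mentioned in the introduction: a view is a distinct object backed by the same underlying data, so writes through one are visible through the other. A minimal NumPy illustration:

```python
import numpy as np

a = np.arange(4)
b = a.view()   # b is a new array object...
assert b is not a

b[0] = 99      # ...but it shares a's underlying buffer
assert a[0] == 99
```

`nnx.view` applies the same idea at the module level: each view holds the same parameter arrays while carrying its own mode settings.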
## Example with `nnx.view`
+++
We first set up data generators and define train/eval step functions. The `train_step` receives an `nnx.Rngs` object for dropout randomness, while `eval_step` doesn't since dropout is disabled in `eval_model`.
```{code-cell} ipython3
ndata, batch_size, total_epochs, lr = 2048, 32, 100, 1e-3
rngs = nnx.Rngs(0)
x = rngs.normal((ndata, in_dim))
y = rngs.normal((ndata, out_dim))
@nnx.jit
def train_step(model, optimizer, x, y, rngs):
    def loss_fn(model, rngs):
        return ((model(x, rngs=rngs) - y) ** 2).mean()
    grads = nnx.grad(loss_fn)(model, rngs)
    optimizer.update(model, grads)
@nnx.jit
def eval_step(model, x, y):
    return ((model(x) - y) ** 2).mean()
```
Now we create `train_model` and `eval_model` views up front. During the training loop we simply use the appropriate view — there is no need to call `.train()` or `.eval()`, and it is always clear from the code which mode the model is in.
```{code-cell} ipython3
import optax

model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=nnx.Rngs(0))
train_model = nnx.view(model, deterministic=False)  # create training view
eval_model = nnx.view(model, deterministic=True)  # create evaluation view
optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)

eval_results = []
for epoch in range(total_epochs):
    for i in range(ndata // batch_size):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        train_step(train_model, optimizer, x[batch], y[batch], rngs)  # use train_model
    eval_results.append(eval_step(eval_model, x, y))  # use eval_model
plt.plot(eval_results)
plt.show()
```
## Writing modules compatible with `nnx.view`
You can make any custom module work with `nnx.view` by defining a `set_view` method. When `nnx.view` is called, it traverses the module tree and calls `set_view` on every submodule that defines one, forwarding the keyword arguments you passed.
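To make the traversal concrete, here is a simplified, plain-Python sketch of the dispatch described above. Note that `apply_view` is a hypothetical stand-in, not the actual `nnx.view` implementation; in particular, the real `nnx.view` returns a new view object that shares the same arrays, rather than mutating the module in place as this sketch does.

```python
def apply_view(module, **kwargs):
    """Call set_view on a module and, recursively, on its submodules."""
    if hasattr(module, "set_view"):
        module.set_view(**kwargs)
    for child in vars(module).values():
        if hasattr(child, "__dict__"):  # descend into submodules
            apply_view(child, **kwargs)

class Dropout:
    def __init__(self):
        self.deterministic = False

    def set_view(self, *, deterministic=None, **kwargs):
        # only update the mode when explicitly requested
        if deterministic is not None:
            self.deterministic = deterministic
        return kwargs

class Net:
    def __init__(self):
        self.do = Dropout()

net = Net()
apply_view(net, deterministic=True)
assert net.do.deterministic is True
```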
Your `set_view` method should follow these conventions:
1. **Accept keyword arguments with `None` defaults.** Each kwarg represents a configurable mode for this module. A `None` default means "leave unchanged", so views only override the modes you explicitly set.
2. **Only update the attribute when the kwarg is not `None`.** This ensures that unrelated views don't accidentally reset each other's settings.
3. **Accept `**kwargs` and return it.** This lets other submodules in the tree consume their own keyword arguments without raising errors about unexpected kwargs.
4. **Include a Google-style docstring.** The `nnx.view_info` function parses these docstrings to display human-readable information about available view options.
We can use `nnx.view_info` to inspect what view options `PrintLayer` exposes. This is especially handy when working with unfamiliar models — it lists every submodule that defines `set_view`, along with the accepted kwargs, their types, defaults, and docstring descriptions.
```{code-cell} ipython3
# Display the information for nnx.view
print(nnx.view_info(model))
```
The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.