Skip to content

Commit d747426

Browse files
committed
Rename NNX view functions to have common naming convention.
1 parent a138d9f commit d747426

File tree

20 files changed

+114
-115
lines changed

20 files changed

+114
-115
lines changed

docs_nnx/guides/view.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"metadata": {},
77
"source": [
88
"# Model Views\n",
9-
"This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:"
9+
"This guide covers how to use NNX \"Views\", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes."
1010
]
1111
},
1212
{
@@ -25,8 +25,8 @@
2525
")\n",
2626
"\n",
2727
"# set train and eval modes\n",
28-
"train_model = nnx.view(model, deterministic=False, use_running_average=False)\n",
29-
"eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n",
28+
"train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)\n",
29+
"eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)\n",
3030
"\n",
3131
"# Can see deterministic is different between train_model and eval_model\n",
3232
"assert train_model.layers[2].deterministic == False\n",
@@ -35,7 +35,7 @@
3535
"# Weights are shared between the models\n",
3636
"assert train_model.layers[0].kernel is eval_model.layers[0].kernel\n",
3737
"\n",
38-
"# Print information about kwargs for nnx.view with nnx.view_info\n",
38+
"# Print information about kwargs for nnx.with_modules with nnx.view_info\n",
3939
"print(nnx.view_info(model))"
4040
]
4141
},
@@ -125,8 +125,8 @@
125125
"metadata": {},
126126
"outputs": [],
127127
"source": [
128-
"train_model = nnx.view(model, deterministic=False)\n",
129-
"eval_model = nnx.view(model, deterministic=True)\n",
128+
"train_model = nnx.with_modules(model, deterministic=False)\n",
129+
"eval_model = nnx.with_modules(model, deterministic=True)\n",
130130
"\n",
131131
"# weights are references to the same data\n",
132132
"assert train_model.lin1.kernel is eval_model.lin1.kernel\n",
@@ -196,8 +196,8 @@
196196
"source": [
197197
"model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)\n",
198198
"optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)\n",
199-
"train_model = nnx.view(model, deterministic=False) # training view\n",
200-
"eval_model = nnx.view(model, deterministic=True) # eval view\n",
199+
"train_model = nnx.with_modules(model, deterministic=False) # training view\n",
200+
"eval_model = nnx.with_modules(model, deterministic=True) # eval view\n",
201201
"\n",
202202
"eval_results = []\n",
203203
"for epoch in range(total_epochs):\n",
@@ -293,7 +293,7 @@
293293
"\n",
294294
"\n",
295295
"model = PrintLayer()\n",
296-
"model_print = nnx.view(model, msg='Hello, World!')\n",
296+
"model_print = nnx.with_modules(model, msg='Hello, World!')\n",
297297
"\n",
298298
"model() # nothing printed\n",
299299
"model_print() # prints \"Hello, World!\""

docs_nnx/guides/view.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jupytext:
99
---
1010

1111
# Model Views
12-
This guide covers how to use the `nnx.view` function. This function is useful for handling state in layers like `Dropout` and `BatchNorm`, which behave differently in training and evaluation. Similar to `.view` for numpy arrays, `nnx.view` allows you to set modes of the model while still sharing the same data. For a quick intro to how this function works, refer to the following example:
12+
This guide covers how to use NNX "Views", which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, an NNX View that overwrites module attributes.
1313

1414
```{code-cell}
1515
from flax import nnx
@@ -21,8 +21,8 @@ model = nnx.Sequential(
2121
)
2222
2323
# set train and eval modes
24-
train_model = nnx.view(model, deterministic=False, use_running_average=False)
25-
eval_model = nnx.view(model, deterministic=True, use_running_average=True)
24+
train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)
25+
eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)
2626
2727
# Can see deterministic is different between train_model and eval_model
2828
assert train_model.layers[2].deterministic == False
@@ -31,7 +31,7 @@ assert eval_model.layers[2].deterministic == True
3131
# Weights are shared between the models
3232
assert train_model.layers[0].kernel is eval_model.layers[0].kernel
3333
34-
# Print information about kwargs for nnx.view with nnx.view_info
34+
# Print information about kwargs for nnx.with_modules with nnx.view_info
3535
print(nnx.view_info(model))
3636
```
3737

@@ -85,8 +85,8 @@ From the model display, we can see that `Dropout` has `deterministic == False`,
8585
This is where `nnx.view` comes in. This function updates the modes for each submodule of a neural network based on the kwargs passed into the function. The underlying model weights are then shared between different views. We set up a training and evaluation version of the model below.
8686

8787
```{code-cell}
88-
train_model = nnx.view(model, deterministic=False)
89-
eval_model = nnx.view(model, deterministic=True)
88+
train_model = nnx.with_modules(model, deterministic=False)
89+
eval_model = nnx.with_modules(model, deterministic=True)
9090
9191
# weights are references to the same data
9292
assert train_model.lin1.kernel is eval_model.lin1.kernel
@@ -128,8 +128,8 @@ Now we create `train_model` and `eval_model` views up front. During the training
128128
```{code-cell}
129129
model = MyModel(in_dim, hidden_dim, out_dim, 0.1, rngs=rngs)
130130
optimizer = nnx.Optimizer(model, optax.adam(lr), wrt=nnx.Param)
131-
train_model = nnx.view(model, deterministic=False) # training view
132-
eval_model = nnx.view(model, deterministic=True) # eval view
131+
train_model = nnx.with_modules(model, deterministic=False) # training view
132+
eval_model = nnx.with_modules(model, deterministic=True) # eval view
133133
134134
eval_results = []
135135
for epoch in range(total_epochs):
@@ -201,7 +201,7 @@ class PrintLayer(nnx.Module):
201201
202202
203203
model = PrintLayer()
204-
model_print = nnx.view(model, msg='Hello, World!')
204+
model_print = nnx.with_modules(model, msg='Hello, World!')
205205
206206
model() # nothing printed
207207
model_print() # prints "Hello, World!"

docs_nnx/hijax/hijax.ipynb

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"@jax.jit\n",
5050
"def train_step(x, y):\n",
5151
" loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n",
52-
" loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n",
52+
" loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad\n",
5353
" optimizer.update(model, grads)\n",
5454
" return loss\n",
5555
"\n",
@@ -297,8 +297,8 @@
297297
"\n",
298298
"model = Linear(1, 3, rngs=nnx.Rngs(0))\n",
299299
"\n",
300-
"print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n",
301-
"print(f\"{nnx.vars_as(model, mutable=True) = !s}\")"
300+
"print(f\"{nnx.with_vars(model, mutable=False) = !s}\")\n",
301+
"print(f\"{nnx.with_vars(model, mutable=True) = !s}\")"
302302
]
303303
},
304304
{
@@ -317,7 +317,7 @@
317317
],
318318
"source": [
319319
"v = nnx.Variable(jnp.array(0))\n",
320-
"v_immut = nnx.vars_as(v, mutable=False)\n",
320+
"v_immut = nnx.with_vars(v, mutable=False)\n",
321321
"assert not v_immut.mutable\n",
322322
"\n",
323323
"try:\n",
@@ -355,7 +355,7 @@
355355
],
356356
"source": [
357357
"v = nnx.Variable(jnp.array(0))\n",
358-
"v_ref = nnx.vars_as(v, ref=True)\n",
358+
"v_ref = nnx.with_vars(v, ref=True)\n",
359359
"assert v_ref.ref\n",
360360
"print(v_ref)\n",
361361
"print(v_ref.get_raw_value())"
@@ -386,11 +386,11 @@
386386
}
387387
],
388388
"source": [
389-
"v_immut = nnx.vars_as(v_ref, mutable=False)\n",
389+
"v_immut = nnx.with_vars(v_ref, mutable=False)\n",
390390
"assert not v_immut.ref\n",
391391
"print(\"immutable =\", v_immut)\n",
392392
"\n",
393-
"v_ref = nnx.vars_as(v_immut, mutable=True)\n",
393+
"v_ref = nnx.with_vars(v_immut, mutable=True)\n",
394394
"assert v_ref.ref\n",
395395
"print(\"mutable =\", v_ref)"
396396
]
@@ -458,7 +458,7 @@
458458
" model = nnx.merge(graphdef, params, nondiff)\n",
459459
" return ((model(x) - y) ** 2).mean()\n",
460460
"\n",
461-
" loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n",
461+
" loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad\n",
462462
" optimizer.update(model, grads)\n",
463463
"\n",
464464
" return loss\n",
@@ -563,9 +563,9 @@
563563
"source": [
564564
"@jax.jit\n",
565565
"def create_model(rngs):\n",
566-
" return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n",
566+
" return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)\n",
567567
"\n",
568-
"model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n",
568+
"model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)\n",
569569
"\n",
570570
"print(\"model.linear =\", model.linear)"
571571
]

docs_nnx/hijax/hijax.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)
2929
@jax.jit
3030
def train_step(x, y):
3131
loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
32-
loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad
32+
loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad
3333
optimizer.update(model, grads)
3434
return loss
3535
@@ -112,13 +112,13 @@ class Linear(nnx.Module):
112112
113113
model = Linear(1, 3, rngs=nnx.Rngs(0))
114114
115-
print(f"{nnx.vars_as(model, mutable=False) = !s}")
116-
print(f"{nnx.vars_as(model, mutable=True) = !s}")
115+
print(f"{nnx.with_vars(model, mutable=False) = !s}")
116+
print(f"{nnx.with_vars(model, mutable=True) = !s}")
117117
```
118118

119119
```{code-cell} ipython3
120120
v = nnx.Variable(jnp.array(0))
121-
v_immut = nnx.vars_as(v, mutable=False)
121+
v_immut = nnx.with_vars(v, mutable=False)
122122
assert not v_immut.mutable
123123
124124
try:
@@ -131,18 +131,18 @@ except Exception as e:
131131

132132
```{code-cell} ipython3
133133
v = nnx.Variable(jnp.array(0))
134-
v_ref = nnx.vars_as(v, ref=True)
134+
v_ref = nnx.with_vars(v, ref=True)
135135
assert v_ref.ref
136136
print(v_ref)
137137
print(v_ref.get_raw_value())
138138
```
139139

140140
```{code-cell} ipython3
141-
v_immut = nnx.vars_as(v_ref, mutable=False)
141+
v_immut = nnx.with_vars(v_ref, mutable=False)
142142
assert not v_immut.ref
143143
print("immutable =", v_immut)
144144
145-
v_ref = nnx.vars_as(v_immut, mutable=True)
145+
v_ref = nnx.with_vars(v_immut, mutable=True)
146146
assert v_ref.ref
147147
print("mutable =", v_ref)
148148
```
@@ -176,7 +176,7 @@ def train_step(model, optimizer, x, y):
176176
model = nnx.merge(graphdef, params, nondiff)
177177
return ((model(x) - y) ** 2).mean()
178178
179-
loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad
179+
loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad
180180
optimizer.update(model, grads)
181181
182182
return loss
@@ -226,9 +226,9 @@ except Exception as e:
226226
```{code-cell} ipython3
227227
@jax.jit
228228
def create_model(rngs):
229-
return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)
229+
return nnx.with_vars((Block(2, 64, 3, rngs=rngs)), hijax=False)
230230
231-
model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)
231+
model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)
232232
233233
print("model.linear =", model.linear)
234234
```

docs_nnx/mnist_tutorial.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@
303303
"\n",
304304
"## 6. Train and evaluate the model\n",
305305
"\n",
306-
"Now, you can train the CNN model. Before the training loop, we use [`nnx.view`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation."
306+
"Now, you can train the CNN model. Before the training loop, we use [`nnx.with_modules`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation."
307307
]
308308
},
309309
{
@@ -335,8 +335,8 @@
335335
"}\n",
336336
"\n",
337337
"rngs = nnx.Rngs(0)\n",
338-
"train_model = nnx.view(model, deterministic=False, use_running_average=False)\n",
339-
"eval_model = nnx.view(model, deterministic=True, use_running_average=True)\n",
338+
"train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)\n",
339+
"eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)\n",
340340
"\n",
341341
"for step, batch in enumerate(train_ds.as_numpy_iterator()):\n",
342342
" # Run the optimization for one step and make a stateful update to the following:\n",
@@ -380,7 +380,7 @@
380380
"source": [
381381
"## 7. Perform inference on the test set\n",
382382
"\n",
383-
"Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (an `nnx.view` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance."
383+
"Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (using `nnx.with_modules` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance."
384384
]
385385
},
386386
{

docs_nnx/mnist_tutorial.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_ref
173173

174174
## 6. Train and evaluate the model
175175

176-
Now, you can train the CNN model. Before the training loop, we use [`nnx.view`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation.
176+
Now, you can train the CNN model. Before the training loop, we use [`nnx.with_modules`](https://flax.readthedocs.io/en/latest/guides/view.html) to create a `train_model` (with dropout enabled and batch norm in training mode) and an `eval_model` (with dropout disabled and batch norm using running statistics). These views share the same underlying weights, so updates during training are automatically reflected during evaluation.
177177

178178
```{code-cell} ipython3
179179
from IPython.display import clear_output
@@ -187,8 +187,8 @@ metrics_history = {
187187
}
188188
189189
rngs = nnx.Rngs(0)
190-
train_model = nnx.view(model, deterministic=False, use_running_average=False)
191-
eval_model = nnx.view(model, deterministic=True, use_running_average=True)
190+
train_model = nnx.with_modules(model, deterministic=False, use_running_average=False)
191+
eval_model = nnx.with_modules(model, deterministic=True, use_running_average=True)
192192
193193
for step, batch in enumerate(train_ds.as_numpy_iterator()):
194194
# Run the optimization for one step and make a stateful update to the following:
@@ -227,7 +227,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
227227

228228
## 7. Perform inference on the test set
229229

230-
Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (an `nnx.view` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.
230+
Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. Since we already have `eval_model` (using `nnx.with_modules` with `deterministic=True` and `use_running_average=True`), we can use it directly for inference. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.
231231

232232
```{code-cell} ipython3
233233
@nnx.jit

examples/nnx_toy_examples/hijax_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def loss_fn(params):
6868
model = nnx.merge(graphdef, params, nondiff)
6969
return jnp.mean((y - model(x)) ** 2)
7070

71-
grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False))
71+
grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False))
7272
optimizer.update(model, grads)
7373

7474
@jax.jit

examples/nnx_toy_examples/hijax_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def loss_fn(params):
238238

239239
# For the time being we have to use 'immutable'
240240
# as 'jax.grad' doesn't support QDD types yet.
241-
grads = jax.grad(loss_fn)(nnx.vars_as(params, is_mutable=False))
241+
grads = jax.grad(loss_fn)(nnx.with_vars(params, is_mutable=False))
242242
# 'update' mutates the optimizer's state and the params in place
243243
# so we don't need to return anything 🚀
244244
optimizer.update(params, grads)

flax/nnx/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from .module import M as M
5151
from .module import Module as Module
5252
from .module import capture as capture
53-
from .module import view as view
53+
from .module import with_modules as with_modules
5454
from .module import view_info as view_info
5555
from .module import with_attributes as with_attributes
5656
from .module import iter_children as iter_children, iter_modules as iter_modules
@@ -75,8 +75,8 @@
7575
from .graphlib import MergeContext as MergeContext
7676
from .graphlib import merge_context as merge_context
7777
from .graphlib import variables as variables
78-
from .graphlib import vars_as as vars_as
79-
from .graphlib import pure as pure
78+
from .graphlib import with_vars as with_vars
79+
from .graphlib import as_pure as as_pure
8080
from .graphlib import cached_partial as cached_partial
8181
from .graphlib import flatten as flatten
8282
from .graphlib import unflatten as unflatten
@@ -152,7 +152,7 @@
152152
from .spmd import get_named_sharding as get_named_sharding
153153
from .spmd import with_partitioning as with_partitioning
154154
from .spmd import get_abstract_model as get_abstract_model
155-
from .spmd import abstract_with_sharding as abstract_with_sharding
155+
from .spmd import as_abstract as as_abstract
156156
from .statelib import FlatState as FlatState
157157
from .statelib import State as State
158158
from .statelib import to_flat_state as to_flat_state

flax/nnx/compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
recursive_map = functools.partial(_graphlib.recursive_map, graph=True)
4040

4141
# module
42-
view = functools.partial(_module.view, graph=True)
42+
view = functools.partial(_module.with_modules, graph=True)
4343
view_info = functools.partial(_module.view_info, graph=True)
4444
iter_modules = functools.partial(_module.iter_modules, graph=True)
4545
iter_children = functools.partial(_module.iter_children, graph=True) # type: ignore[has-type]

0 commit comments

Comments
 (0)