|
6 | 6 | "metadata": {}, |
7 | 7 | "source": [ |
8 | 8 | "# Model Views\n", |
9 | | - "This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, a NNX View that sets module modes." |
| 9 | + "This guide covers how to use NNX Views, which are useful for handling state in layers like `Dropout` and `BatchNorm` which behave differently in training and evaluation. Similar to `.view` for numpy arrays, NNX Views allow you to modify static attributes of the model while still sharing the same data. For a quick intro, consider the following example showcasing `nnx.with_modules`, a NNX View that sets module modes.\n", |
| 10 | + "\n", |
| 11 | + "NNX follows a naming convention for view-creating functions: names starting with `with_` return a new version of the input with modified module or variable attributes, while names starting with `as_` return a new tree with variables transformed into a different representation. In both cases the underlying JAX array data is shared with the original." |
10 | 12 | ] |
11 | 13 | }, |
12 | 14 | { |
|
320 | 322 | }, |
321 | 323 | { |
322 | 324 | "cell_type": "markdown", |
323 | | - "id": "1acbcc09", |
| 325 | + "id": "984b8eca", |
324 | 326 | "metadata": {}, |
325 | 327 | "source": [ |
326 | 328 | "The output shows that `PrintLayer` accepts a `msg` kwarg of type `bool` in its `set_view` method. When building larger models composed of many custom submodules, `nnx.view_info` gives you a quick summary of all the configurable modes across the entire module tree.\n", |
327 | 329 | "\n", |
| 330 | + "## Using `with_vars`\n", |
| 331 | + "\n", |
| 332 | + "{func}`nnx.with_vars <flax.nnx.with_vars>` creates a view of a module tree by replacing ``Variable`` objects with copies that have different low-level JAX flags, while leaving the underlying array data shared. Unlike `with_modules` and `with_attributes`, which change Python-level attributes on module objects, `with_vars` controls how ``Variable`` values are represented inside JAX.\n", |
| 333 | + "\n", |
| 334 | + "The flags it controls are:\n", |
| 335 | + "\n", |
| 336 | + "- **`ref`** — when `True`, each Variable's value is backed by a `jax.Ref`. This makes the module a valid pytree leaf for `jax.tree.map` and other JAX utilities that treat refs as mutable state.\n", |
| 337 | + "- **`hijax`** — when `True`, Variables participate in JAX's *hijax* protocol and become first-class JAX values that can flow through `jax.grad`, `jax.jit`, and similar transforms without an explicit split/merge step.\n", |
| 338 | + "- **`mutable`** — when `True`, marks Variables as mutable within a JAX transform.\n", |
| 339 | + "\n", |
| 340 | + "The `only` argument accepts a {doc}`Filter <filters_guide>` to restrict which Variables are affected; unmatched Variables are returned as-is (shared with the original)." |
| 341 | + ] |
| 342 | + }, |
| 343 | + { |
| 344 | + "cell_type": "code", |
| 345 | + "execution_count": null, |
| 346 | + "id": "0938e1f6", |
| 347 | + "metadata": {}, |
| 348 | + "outputs": [], |
| 349 | + "source": [ |
| 350 | + "from flax import nnx\n", |
| 351 | + "import jax\n", |
| 352 | + "import jax.numpy as jnp\n", |
| 353 | + "\n", |
| 354 | + "class SimpleModel(nnx.Module):\n", |
| 355 | + " def __init__(self, rngs):\n", |
| 356 | + " self.linear = nnx.Linear(2, 3, rngs=rngs)\n", |
| 357 | + "\n", |
| 358 | + "model = SimpleModel(nnx.Rngs(0))\n", |
| 359 | + "\n", |
| 360 | + "# ref=True: expose Variable values as JAX refs so jax.tree.map can update them\n", |
| 361 | + "ref_model = nnx.with_vars(model, ref=True)\n", |
| 362 | + "ref_model = jax.tree.map(lambda x: x * 2, ref_model)\n", |
| 363 | + "\n", |
| 364 | + "# The original model's kernel is unchanged; ref_model has doubled values\n", |
| 365 | + "assert model.linear.kernel is not ref_model.linear.kernel" |
| 366 | + ] |
| 367 | + }, |
| 368 | + { |
| 369 | + "cell_type": "markdown", |
| 370 | + "id": "67c71d62", |
| 371 | + "metadata": {}, |
| 372 | + "source": [ |
| 373 | + "Use the `only` filter to convert only a subset of Variables:" |
| 374 | + ] |
| 375 | + }, |
| 376 | + { |
| 377 | + "cell_type": "code", |
| 378 | + "execution_count": null, |
| 379 | + "id": "254fe344", |
| 380 | + "metadata": {}, |
| 381 | + "outputs": [], |
| 382 | + "source": [ |
| 383 | + "# only convert Param variables, leave BatchStat variables unchanged\n", |
| 384 | + "ref_params = nnx.with_vars(model, ref=True, only=nnx.Param)" |
| 385 | + ] |
| 386 | + }, |
| 387 | + { |
| 388 | + "cell_type": "markdown", |
| 389 | + "id": "1acbcc09", |
| 390 | + "metadata": {}, |
| 391 | + "source": [ |
328 | 392 | "## Using `with_attributes`\n", |
329 | 393 | "\n", |
330 | 394 | "If you are working with modules that don't implement the `set_view` API, you can use {func}`nnx.with_attributes <flax.nnx.with_attributes>` to create views by directly replacing their attributes. Like `nnx.with_modules`, it returns a new instance that shares jax arrays with the original, leaving the original unchanged." |
|
412 | 476 | "id": "bf521e45", |
413 | 477 | "metadata": {}, |
414 | 478 | "source": [ |
415 | | - "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`." |
| 479 | + "Here `recursive_map` visited each node, and when it found an `nnx.Linear` instance it created a `NoisyLinear`, swapped in the original `Linear` as its inner layer, and returned it. The original `model` is unchanged and its weights are shared with `noisy_model`.\n", |
| 480 | + "\n", |
| 481 | + "## Other NNX views\n", |
| 482 | + "\n", |
| 483 | + "Several other NNX functions follow the `with_` / `as_` naming convention and produce views or transformed trees:\n", |
| 484 | + "\n", |
| 485 | + "- {func}`nnx.as_pure <flax.nnx.as_pure>` — strips all ``Variable`` wrappers from a pytree and returns the raw inner values. This is useful for serialization or export, where Variable metadata is not needed.\n", |
| 486 | + "\n", |
| 487 | + " ```python\n", |
| 488 | + " _, state = nnx.split(model)\n", |
| 489 | + " pure_state = nnx.as_pure(state) # Variable wrappers removed; plain arrays remain\n", |
| 490 | + " ```\n", |
| 491 | + "\n", |
| 492 | + "- {func}`nnx.as_abstract <flax.nnx.as_abstract>` — annotates the abstract ``Variable`` objects produced by {func}`nnx.eval_shape` with sharding information derived from each Variable's `out_sharding` metadata. Used when working with JAX auto-sharding meshes.\n", |
| 493 | + "\n", |
| 494 | + " ```python\n", |
| 495 | + " with jax.set_mesh(mesh):\n", |
| 496 | + " abs_model = nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))\n", |
| 497 | + " abs_model = nnx.as_abstract(abs_model) # sharding attached to abstract vars\n", |
| 498 | + " ```\n", |
| 499 | + "\n", |
| 500 | + "- {func}`nnx.with_rngs <flax.nnx.rnglib.with_rngs>` — returns a copy of a pytree with ``RngStream`` objects split or forked according to filter rules. Used to prepare RNG state before JAX transforms like `vmap` that require per-device or per-replica keys.\n", |
| 501 | + "\n", |
| 502 | + " ```python\n", |
| 503 | + " # Split params stream into 4 keys (one per vmap replica); fork the rest\n", |
| 504 | + " vmapped_rngs = nnx.with_rngs(rngs, split={'params': 4}, fork=...)\n", |
| 505 | + " ```" |
416 | 506 | ] |
417 | 507 | } |
418 | 508 | ], |
|
0 commit comments