|
49 | 49 | "@jax.jit\n", |
50 | 50 | "def train_step(x, y):\n", |
51 | 51 | " loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n", |
52 | | - " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False)) # tmp fix for jax.grad\n", |
| 52 | + " loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(model, mutable=False)) # tmp fix for jax.grad\n", |
53 | 53 | " optimizer.update(model, grads)\n", |
54 | 54 | " return loss\n", |
55 | 55 | "\n", |
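A note on this hunk: `model` and `optimizer` are closed over from the enclosing scope rather than passed in. A minimal setup sketch under the API shown in this diff, assuming a simple module like the `Linear` used later in the notebook and the `nnx.Optimizer(model, tx, wrt=...)` constructor:

import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = Linear(1, 3, rngs=nnx.Rngs(0))                            # defined later in the notebook
optimizer = nnx.Optimizer(model, optax.sgd(1e-2), wrt=nnx.Param)  # optax transform wrapped for NNX

x, y = jnp.ones((8, 1)), jnp.zeros((8, 3))
loss = train_step(x, y)   # gradients flow through the immutable view of the model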
|
297 | 297 | "\n", |
298 | 298 | "model = Linear(1, 3, rngs=nnx.Rngs(0))\n", |
299 | 299 | "\n", |
300 | | - "print(f\"{nnx.vars_as(model, mutable=False) = !s}\")\n", |
301 | | - "print(f\"{nnx.vars_as(model, mutable=True) = !s}\")" |
| 300 | + "print(f\"{nnx.with_vars(model, mutable=False) = !s}\")\n", |
| 301 | + "print(f\"{nnx.with_vars(model, mutable=True) = !s}\")" |
302 | 302 | ] |
303 | 303 | }, |
304 | 304 | { |
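`Linear` itself is outside this diff; a compatible definition, sketched from the standard Flax NNX module pattern (the exact notebook definition may differ):

class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.w = nnx.Param(jax.random.uniform(key, (din, dout)))  # weight matrix
        self.b = nnx.Param(jnp.zeros((dout,)))                    # bias vector

    def __call__(self, x):
        return x @ self.w + self.b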
|
317 | 317 | ], |
318 | 318 | "source": [ |
319 | 319 | "v = nnx.Variable(jnp.array(0))\n", |
320 | | - "v_immut = nnx.vars_as(v, mutable=False)\n", |
| 320 | + "v_immut = nnx.with_vars(v, mutable=False)\n", |
321 | 321 | "assert not v_immut.mutable\n", |
322 | 322 | "\n", |
323 | 323 | "try:\n", |
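The `try:` block is cut off by the hunk boundary; presumably it demonstrates that writing to the immutable view fails. A sketch of that check, assuming an in-place assignment raises (the exact exception type is not shown in this diff):

try:
    v_immut.value = jnp.array(1)    # attempt an in-place update
except Exception as e:
    print("mutation rejected:", e)  # immutable views refuse writes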
|
355 | 355 | ], |
356 | 356 | "source": [ |
357 | 357 | "v = nnx.Variable(jnp.array(0))\n", |
358 | | - "v_ref = nnx.vars_as(v, ref=True)\n", |
| 358 | + "v_ref = nnx.with_vars(v, ref=True)\n", |
359 | 359 | "assert v_ref.ref\n", |
360 | 360 | "print(v_ref)\n", |
361 | 361 | "print(v_ref.get_raw_value())" |
|
386 | 386 | } |
387 | 387 | ], |
388 | 388 | "source": [ |
389 | | - "v_immut = nnx.vars_as(v_ref, mutable=False)\n", |
| 389 | + "v_immut = nnx.with_vars(v_ref, mutable=False)\n", |
390 | 390 | "assert not v_immut.ref\n", |
391 | 391 | "print(\"immutable =\", v_immut)\n", |
392 | 392 | "\n", |
393 | | - "v_ref = nnx.vars_as(v_immut, mutable=True)\n", |
| 393 | + "v_ref = nnx.with_vars(v_immut, mutable=True)\n", |
394 | 394 | "assert v_ref.ref\n", |
395 | 395 | "print(\"mutable =\", v_ref)" |
396 | 396 | ] |
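These two hunks together suggest the conversions are inverses: `ref=True` and `mutable=True` yield reference views, while `mutable=False` yields a plain value view. A round-trip sketch under that reading:

v = nnx.Variable(jnp.array(0))
v_ref = nnx.with_vars(v, ref=True)             # reference view
v_immut = nnx.with_vars(v_ref, mutable=False)  # value view, ref dropped
v_back = nnx.with_vars(v_immut, mutable=True)  # reference view restored
assert v_ref.ref and not v_immut.ref and v_back.ref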
|
458 | 458 | " model = nnx.merge(graphdef, params, nondiff)\n", |
459 | 459 | " return ((model(x) - y) ** 2).mean()\n", |
460 | 460 | "\n", |
461 | | - " loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad\n", |
| 461 | + " loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False)) # immutable for jax.grad\n", |
462 | 462 | " optimizer.update(model, grads)\n", |
463 | 463 | "\n", |
464 | 464 | " return loss\n", |
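Only the tail of this training step is visible; the `graphdef`, `params`, and `nondiff` it merges must come from a split earlier in the function. A sketch of the full functional pattern, assuming a standard `nnx.split` on `nnx.Param` with `...` catching the non-differentiable rest:

@jax.jit
def train_step(x, y):
    # partition model state: trainable params vs. everything else
    graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)

    def loss_fn(params):
        model = nnx.merge(graphdef, params, nondiff)  # rebuild inside the trace
        return ((model(x) - y) ** 2).mean()

    loss, grads = jax.value_and_grad(loss_fn)(nnx.with_vars(params, mutable=False))  # immutable for jax.grad
    optimizer.update(model, grads)
    return loss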
|
563 | 563 | "source": [ |
564 | 564 | "@jax.jit\n", |
565 | 565 | "def create_model(rngs):\n", |
566 | | - " return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)\n", |
| 566 | + " return nnx.with_vars(Block(2, 64, 3, rngs=rngs), hijax=False)\n", |
567 | 567 | "\n", |
568 | | - "model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)\n", |
| 568 | + "model = nnx.with_vars(create_model(nnx.Rngs(0)), hijax=True)\n", |
569 | 569 | "\n", |
570 | 570 | "print(\"model.linear =\", model.linear)" |
571 | 571 | ] |
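The pattern here reads as: downgrade to non-hijax variables before returning across the `jax.jit` boundary, then upgrade back outside the trace. A short usage sketch, assuming `Block(2, 64, 3, ...)` maps 2 input features to 3 outputs (its definition is outside this excerpt):

x = jnp.ones((4, 2))   # batch of 4, 2 features (assumed shape)
print(model(x).shape)  # expected (4, 3) under that assumption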
|