Skip to content

Commit fa523a9

Browse files
committed
Fix typo
1 parent 008a16f commit fa523a9

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

docs_nnx/guides/Optimization Cookbook.ipynb

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,6 @@
5959
" nnx.Linear(2,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)),\n",
6060
" nnx.Linear(8,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)))\n",
6161
"\n",
62-
"def nnx_loss_fn(params, x, y):\n",
63-
" return jnp.sum((y - model(params, x))**2)\n",
64-
"\n",
6562
"def nnx_loss_fn(model, x, y):\n",
6663
" return jnp.sum((model(x) - y) ** 2)"
6764
]

docs_nnx/guides/Optimization Cookbook.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ jupyter:
66
format_name: markdown
77
format_version: '1.3'
88
jupytext_version: 1.13.8
9+
kernelspec:
10+
display_name: Python 3 (ipykernel)
11+
language: python
12+
name: python3
913
---
1014

1115
# A Flax Optimization Cookbook
@@ -38,9 +42,6 @@ def nnx_model(rngs, **kwargs):
3842
nnx.Linear(2,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)),
3943
nnx.Linear(8,8, rngs=rngs, kernel_init=ft.partial(param_init, **kwargs)))
4044

41-
def nnx_loss_fn(params, x, y):
42-
return jnp.sum((y - model(params, x))**2)
43-
4445
def nnx_loss_fn(model, x, y):
4546
return jnp.sum((model(x) - y) ** 2)
4647
```

0 commit comments

Comments
 (0)