Skip to content

Commit d3b2abd

Browse files
Fix nnx guide error (#2183)
* Update keras_nnx_guide.py * Update keras_nnx_guide.ipynb * Update keras_nnx_guide.md
1 parent 72ecb2f commit d3b2abd

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

guides/ipynb/keras_nnx_guide.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@
342342
" y_pred = model_(x)\n",
343343
" return jnp.mean((y - y_pred) ** 2)\n",
344344
"\n",
345-
" grads = nnx.grad(loss_fn, wrt=trainable_var)(model)\n",
345+
" diff_state = nnx.DiffState(0, trainable_var)\n",
346+
" grads = nnx.grad(loss_fn, argnums=diff_state)(model)\n",
346347
" optimizer.update(model, grads)\n",
347348
"\n",
348349
"\n",
@@ -506,4 +507,4 @@
506507
},
507508
"nbformat": 4,
508509
"nbformat_minor": 0
509-
}
510+
}

guides/keras_nnx_guide.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ def loss_fn(model_):
208208
y_pred = model_(x)
209209
return jnp.mean((y - y_pred) ** 2)
210210

211-
grads = nnx.grad(loss_fn, wrt=trainable_var)(model)
211+
diff_state = nnx.DiffState(0, trainable_var)
212+
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
212213
optimizer.update(model, grads)
213214

214215

guides/md/keras_nnx_guide.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,8 @@ def train_step(model, optimizer, batch):
223223
y_pred = model_(x)
224224
return jnp.mean((y - y_pred) ** 2)
225225

226-
grads = nnx.grad(loss_fn, wrt=trainable_var)(model)
226+
diff_state = nnx.DiffState(0, trainable_var)
227+
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
227228
optimizer.update(model, grads)
228229

229230

0 commit comments

Comments
 (0)