Commit 69d51f2
Add exercise and solution to JAX neural network lecture (#254)
* Add exercise and solution to jax_nn.md
- Added exercise asking readers to reduce MSE without increasing parameters
- Added solution with 4 optimization strategies:
1. Deeper network (6 layers, k=6) - 187 parameters
2. Deeper network + learning rate schedule (best performer)
3. ELU activation
4. SELU activation
- All strategies use same number of epochs (4000) for fair comparison
- Best strategy achieved 0.040878 MSE vs 0.041074 baseline
- Includes comprehensive comparison table and visualization
- Fixed bug in train_jax_optax function (new_θ -> θ_new)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
* misc
---------
Co-authored-by: Claude <[email protected]>1 parent e98e55d commit 69d51f2
1 file changed
+347
-24
lines changed
0 commit comments