Fix MNIST performance: 10x faster backward pass #10
Conversation
## Performance Improvements

### Sequential backward pass fix
- Fixed O(N²) forward passes during backpropagation in `Chain._vjpCallAsFunction`
- Capture pullback during forward pass instead of recomputing
- Backward pass: 213ms → 63ms (3.4x improvement)

### BatchNorm optimization
- Optimized batch normalization with proper gradient computation
- Use PyTorch's native `batch_norm` for forward pass

### Linear layer improvements
- Improved linear layer with efficient weight gradient computation

### Matrix operations
- Added differentiable matrix multiplication reverse
- Optimized tensor NN operations

### Pip-installed PyTorch (MKL/MKLDNN)
- Replace source builds with `pip install torch==2.8.0`
- Enables MKL/MKLDNN optimizations (~10x faster backward)
- 30-second install vs 30-60 minute source builds

## Combined Performance Results

| Metric | Before | After | Speedup |
|--------|--------|-------|---------|
| Forward | 101ms | 50ms | 2x |
| Backward | 1383ms | 132ms | 10.5x |
| Total (3 epochs) | 2315s | 385s | 6x |

Closes #6

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@pedronahum nice! Btw, have you considered downloading just
Hi @thinkpractice, Happy to hear that 2.9 also works. Let's upgrade indeed! To my surprise, how libtorch was compiled from source was a big driver of the performance difference! So thanks for pointing me to the already compiled library. This will make the installation a much nicer experience. I will not have access to my PCs until Monday. We could wait, or if you want to send a PR with these changes, just let me know. Best, Pedro N
Hi @pedronahum,
Nice! I think every update will come with some enhancements. So nice to be able to run it with the latest version.
I somehow expected it a bit. Of course we can replicate the build options and process of
Didn't have time this weekend, but could have a look at it today! Let me know 😄, kind regards, Tim
Hi @pedronahum, pushed some changes to my
I added some macOS auto-detection but haven't tested this. I'm also missing a
kind regards, Tim
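For reference, the macOS auto-detection mentioned above usually reduces to a platform switch. A minimal sketch, assuming nothing about the repo's actual shell scripts (Python here for brevity; `detect_platform` is a made-up name):

```python
import platform

def detect_platform():
    # Map the host OS to an install flavor; purely illustrative of the kind
    # of macOS/Linux auto-detection the scripts add, not the repo's code.
    system = platform.system()
    if system == "Darwin":
        # Apple Silicon reports "arm64", Intel Macs "x86_64".
        arch = "arm64" if platform.machine() == "arm64" else "x86_64"
        return "macos-" + arch
    if system == "Linux":
        return "linux"
    return "unsupported"
```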
Summary
This PR addresses #6 by fixing multiple performance bottlenecks in TaylorTorch's MNIST training.
Key Fixes
- **Sequential backward pass O(N²) bug** - The `Chain._vjpCallAsFunction` was redundantly calling forward passes during backpropagation. For N layers, this caused O(N²) forward passes instead of 0.
- **BatchNorm optimization** - Improved batch normalization with proper gradient computation using PyTorch's native `batch_norm`.
- **Linear layer improvements** - Efficient weight gradient computation.
- **Pip-installed PyTorch** - Replaced source builds with `pip install torch==2.8.0`, which includes MKL/MKLDNN optimizations.

Performance Results
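As an aside on the BatchNorm item under Key Fixes: training-mode batch norm and its input gradient can be written out in a few NumPy lines. This is a generic reference sketch of the standard formulas, not the Swift code in `BatchNorm.swift`:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # x: (N, C). Training-mode batch norm: normalize each feature column
    # by the batch mean and (biased) batch variance, then scale and shift.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    xhat = (x - mean) / np.sqrt(var + eps)
    return gamma * xhat + beta, (xhat, var, eps)

def batchnorm_backward(g, gamma, cache):
    # Standard closed-form VJP for batch norm, per feature column.
    xhat, var, eps = cache
    n = g.shape[0]
    dgamma = (g * xhat).sum(axis=0)
    dbeta = g.sum(axis=0)
    dx = (gamma / (n * np.sqrt(var + eps))) * (n * g - dbeta - xhat * dgamma)
    return dx, dgamma, dbeta
```

A useful sanity property of this gradient: because `xhat` is zero-mean over the batch, `dx` sums to (numerically) zero along the batch dimension for every feature.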
Files Changed
Performance fixes
- `Sources/Torch/Modules/Layers/Sequential.swift` - Fix O(N²) backward pass
- `Sources/Torch/Modules/Layers/BatchNorm.swift` - Optimized batch norm
- `Sources/Torch/Modules/Layers/Linear.swift` - Efficient gradients
- `Sources/Torch/ATen/Core/Tensor/Tensor+NN+Differentiable.swift` - Matrix ops
- `Sources/ATenCXX/include/tensor_shim.hpp` - C++ shim improvements

Infrastructure (pip PyTorch)
- `.github/workflows/macos-ci.yml` - Use pip PyTorch
- `Dockerfile` - Simplified pip installation
- `scripts/install-taylortorch-ubuntu.sh` - pip approach
- `scripts/check-prerequisites.sh` - Cross-platform support
- `scripts/verify-installation.sh` - Cross-platform detection

Test plan
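The test-plan body did not survive the page scrape. One plausible smoke check, sketched here (the `torch_backend_report` helper is made up, though the `torch.backends` queries are real PyTorch APIs), is to confirm that the pip-installed wheel actually ships with MKL/MKLDNN enabled:

```python
def torch_backend_report():
    """Report whether the installed torch wheel enables MKL/MKLDNN."""
    try:
        import torch
    except ImportError:
        return "torch not installed (pip install torch==2.8.0)"
    mkl = torch.backends.mkl.is_available()
    mkldnn = torch.backends.mkldnn.is_available()
    return f"torch {torch.__version__}: MKL={mkl} MKLDNN={mkldnn}"

if __name__ == "__main__":
    print(torch_backend_report())
```

If either flag reports `False` on a platform where the PR expects MKL acceleration, the ~10x backward-pass claim would be worth re-measuring there.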
Closes #6