
Fix MNIST performance: 10x faster backward pass #10

Open

pedronahum wants to merge 1 commit into main from feature/pip-pytorch-installation

Conversation


pedronahum (Owner) commented Dec 12, 2025

Summary

This PR addresses #6 by fixing multiple performance bottlenecks in TaylorTorch's MNIST training.

Key Fixes

  1. Sequential backward pass O(N²) bug - Chain._vjpCallAsFunction was redundantly re-running forward passes during backpropagation. For N layers, this caused O(N²) extra forward passes where zero are needed.

  2. BatchNorm optimization - Improved batch normalization with proper gradient computation using PyTorch's native batch_norm.

  3. Linear layer improvements - Efficient weight gradient computation.

  4. Pip-installed PyTorch - Replaced source builds with pip install torch==2.8.0, which ships with MKL/MKLDNN optimizations.

Performance Results

| Metric | Before | After | Speedup |
|--------|--------|-------|---------|
| Forward | 101ms | 50ms | 2x |
| Backward | 1383ms | 132ms | 10.5x |
| Total (3 epochs) | 2315s | 385s | 6x |

Files Changed

Performance fixes

  • Sources/Torch/Modules/Layers/Sequential.swift - Fix O(N²) backward pass
  • Sources/Torch/Modules/Layers/BatchNorm.swift - Optimized batch norm
  • Sources/Torch/Modules/Layers/Linear.swift - Efficient gradients
  • Sources/Torch/ATen/Core/Tensor/Tensor+NN+Differentiable.swift - Matrix ops
  • Sources/ATenCXX/include/tensor_shim.hpp - C++ shim improvements

Infrastructure (pip PyTorch)

  • .github/workflows/macos-ci.yml - Use pip PyTorch
  • Dockerfile - Simplified pip installation
  • scripts/install-taylortorch-ubuntu.sh - pip approach
  • scripts/check-prerequisites.sh - Cross-platform support
  • scripts/verify-installation.sh - Cross-platform detection

Test plan

  • All 155 tests pass on Ubuntu
  • macOS CI passes
  • MNIST training ~6x faster end-to-end

Closes #6

## Performance Improvements

### Sequential backward pass fix
- Fixed O(N²) forward passes during backpropagation in Chain._vjpCallAsFunction
- Capture pullback during forward pass instead of recomputing
- Backward pass: 213ms → 63ms (3.4x improvement)

### BatchNorm optimization
- Optimized batch normalization with proper gradient computation
- Use PyTorch's native batch_norm for forward pass
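For reference, the training-mode computation that PyTorch's native batch_norm performs per feature can be sketched in NumPy (an illustrative reference, not TaylorTorch's Swift code):

```python
import numpy as np

# Training-mode batch norm: normalize each feature by its batch statistics,
# then apply a learned scale (gamma) and shift (beta).
def batch_norm_forward(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=0)                    # per-feature batch mean
    var = x.var(axis=0)                      # biased variance, as in training
    x_hat = (x - mean) / np.sqrt(var + eps)  # normalize
    return gamma * x_hat + beta              # scale and shift

x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = batch_norm_forward(x, gamma=np.ones(2), beta=np.zeros(2))
# Each output column now has (approximately) zero mean and unit variance.
```

Delegating this to the native kernel avoids hand-rolled reductions on the Swift side and keeps the gradient computation consistent with PyTorch's.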

### Linear layer improvements
- Improved linear layer with efficient weight gradient computation

### Matrix operations
- Added differentiable matrix multiplication reverse
- Optimized tensor NN operations
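The matrix-multiplication reverse rule behind these changes is the standard VJP: for C = A @ B with upstream gradient G = dL/dC, the input gradients are G @ Bᵀ and Aᵀ @ G. This same rule gives a linear layer's weight gradient. A NumPy sketch (illustrative, not the Swift implementation), with a finite-difference check:

```python
import numpy as np

# VJP of C = A @ B: given upstream gradient G, return (dL/dA, dL/dB).
def matmul_vjp(A, B, G):
    return G @ B.T, A.T @ G

A = np.random.randn(4, 3)
B = np.random.randn(3, 2)
G = np.ones((4, 2))           # upstream gradient of L = sum(A @ B)
dA, dB = matmul_vjp(A, B, G)

# Finite-difference check on one entry of A:
eps = 1e-6
A2 = A.copy(); A2[0, 0] += eps
numeric = ((A2 @ B).sum() - (A @ B).sum()) / eps
assert abs(numeric - dA[0, 0]) < 1e-4
```

Computing these products directly, rather than through generic elementwise fallbacks, is what makes the weight gradients cheap.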

### Pip-installed PyTorch (MKL/MKLDNN)
- Replace source builds with pip install torch==2.8.0
- Enables MKL/MKLDNN optimizations (~10x faster backward)
- 30-second install vs 30-60 minute source builds

## Combined Performance Results
| Metric | Before | After | Speedup |
|--------|--------|-------|---------|
| Forward | 101ms | 50ms | 2x |
| Backward | 1383ms | 132ms | 10.5x |
| Total (3 epochs) | 2315s | 385s | 6x |

Closes #6

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@thinkpractice (Contributor) commented:

@pedronahum nice! Btw, have you considered downloading just libtorch with a wget command, like I did here? I also tested the build with libtorch 2.9.1 and that also works, at least on Ubuntu. Are there any reasons you keep it at 2.8.0?

@pedronahum (Owner, Author) commented:

Hi @thinkpractice,
Tried to quickly use Claude Code to merge what I did on my Mac into my Ubuntu NUC, so I didn't really think about how best to download libtorch. Happy to make the change.

Happy to hear that 2.9 also works. Let's upgrade indeed!

To my surprise, how libtorch was compiled from source was a big driver of the performance difference! So thanks for pointing me to the already compiled library. This will make the installation a much nicer experience.

I will not have access to my PCs until Monday. We could wait, or if you want to send a PR with these changes, just let me know.

Best,

Pedro N

@thinkpractice (Contributor) commented:

Hi @pedronahum,

> Happy to hear that 2.9 also works. Let's upgrade indeed!

Nice! I think every update will come with some enhancements, so it's nice to be able to run it with the latest version.

> To my surprise, how libtorch was compiled from source was a big driver of the performance difference! So thanks for pointing me to the already compiled library. This will make the installation a much nicer experience.

I somehow expected it a bit. Of course we can replicate libtorch's build options and process, but the people behind PyTorch naturally know far more about how to build the library. Replicating this would take time, and if pre-built libtorch works I see no reason to. And of course it will also make installation much easier and let other people try it.

> I will not have access to my PCs until Monday. We could wait, or if you want to send a PR with these changes, just let me know.

Didn't have time this weekend, but I could have a look at it today! Let me know 😄

kind regards,

Tim

@thinkpractice
Copy link
Contributor

Hi @pedronahum, I pushed some changes to my linux-fixes PR. It has two scripts:

  • get-pre-build-torch.sh downloads a specific version of libtorch for either Linux or macOS, or a build for a CUDA or ROCm compute platform
  • setenv.sh sets the environment variables for building TaylorTorch; run it as . setenv.sh (sourced, with the dot)

I added some macOS auto-detection but haven't tested it. I'm also missing an MPS version of libtorch; I thought this existed but couldn't find it in the downloads.

kind regards,

Tim



Development

Successfully merging this pull request may close: MNIST Benchmark vs Pytorch (#6)