Complete neural network benchmarks with Reactant, JAX, and PyTorch#1530

Merged
ChrisRackauckas merged 3 commits into SciML:master from ChrisRackauckas-Claude:ap/nn_bench
Mar 27, 2026

Conversation

@ChrisRackauckas-Claude
Contributor

Summary

Completes the work started in #1026 by @avik-pal. This PR:

  • Rebased onto current master
  • Updated all packages to latest versions (Lux 1.31.3, Reactant 0.2.240, Flux 0.16.9)
  • Added Reactant.jl inference (via @compile) and training (via TrainState + AutoEnzyme API)
  • Added JAX and PyTorch benchmarks via PythonCall, with all timing done entirely in Python to avoid Julia-to-Python call overhead
  • Completed all benchmark sections: MLP (relu), MLP (gelu), MLP + BatchNorm, LeNet-5, and a small ResNet
  • Added both inference and training benchmarks for all model types
  • Configured GPU runner via benchmark_config.toml (["self-hosted", "gpu", "exclusive"])
  • Added CondaPkg.toml for Python dependencies (jax, jaxlib, optax, torch)

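For reference, a minimal CondaPkg.toml declaring the listed Python dependencies might look like the sketch below. This is illustrative only: the merged PR's actual file may pin specific versions, use the conda `pytorch` package name, or install some packages via a `[pip.deps]` section instead.

```toml
# Hypothetical sketch of CondaPkg.toml; not the PR's exact file.
[deps]
python = ">=3.10"
jax = ""
jaxlib = ""
optax = ""
pytorch = ""
```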
Frameworks compared

| Framework | Type | Notes |
| --- | --- | --- |
| Lux.jl | Julia | Primary Julia DL framework |
| Flux.jl | Julia | Comparison Julia DL framework |
| SimpleChains.jl | Julia | CPU-optimized small networks (MLP only) |
| Reactant.jl | Julia | XLA-compiled Lux models |
| JAX | Python | JIT-compiled, functional |
| PyTorch | Python | Eager mode |

Architecture

  • Python models and timing utilities live in nn_benchmark_utils.py (imported via PythonCall)
  • All Python timing uses time.perf_counter() inside Python loops to avoid measuring Julia↔Python call overhead
  • JAX calls use block_until_ready() so asynchronous dispatch is fully completed before the timer stops
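The Python-side timing approach described above can be sketched roughly as follows. The helper name, parameters, and usage are illustrative assumptions, not the actual API of the PR's nn_benchmark_utils.py; the point is that the hot loop and perf_counter calls live entirely in Python, so no Julia-to-Python call overhead lands inside the measured region.

```python
import time
import statistics

def bench(fn, *args, sync=None, warmup=5, runs=20):
    """Time fn(*args) entirely inside a Python loop.

    sync: optional callable applied to the result to force completion of
    asynchronous work (e.g. JAX's block_until_ready) before the clock stops.
    Returns the median wall time over `runs` measured iterations.
    """
    # Warm-up: trigger JIT compilation and caches outside the timed region.
    for _ in range(warmup):
        out = fn(*args)
        if sync is not None:
            sync(out)
    times = []
    for _ in range(runs):
        t0 = time.perf_counter()
        out = fn(*args)
        if sync is not None:
            sync(out)  # wait for async dispatch before reading the clock
        times.append(time.perf_counter() - t0)
    return statistics.median(times)

# Hypothetical JAX usage (names illustrative):
#   t = bench(jitted_apply, params, x,
#             sync=lambda out: out.block_until_ready())
```

For eager PyTorch on GPU, the `sync` hook would instead be something like `lambda _: torch.cuda.synchronize()`; on CPU no hook is needed.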

Tested locally on CPU

All six frameworks were verified working for inference and training on CPU (no GPU was available for local testing).

Supersedes #1026

Test plan

  • Verify CI runs successfully on GPU runner
  • Check that all benchmark plots render correctly
  • Verify Python dependencies install correctly via CondaPkg

🤖 Generated with Claude Code

avik-pal and others added 2 commits March 26, 2026 05:34
- Rebase onto master and update to latest package versions (Lux 1.31, Reactant 0.2.240, Flux 0.16)
- Add Reactant.jl inference (via @compile) and training (via TrainState API)
- Add JAX and PyTorch benchmarks via PythonCall with timing done entirely in Python
- Complete all benchmark sections: MLP relu, MLP gelu, MLP+BatchNorm, LeNet 5, Small ResNet
- Add both inference and training benchmarks for all model types
- Configure GPU runner via benchmark_config.toml
- Add CondaPkg.toml for Python dependencies (jax, torch, optax)
- Add .CondaPkg and __pycache__ to .gitignore

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLBenchmarksOutput that referenced this pull request Mar 27, 2026
Prepare output directory structure for the new NeuralNetworks benchmark
(SciML/SciMLBenchmarks.jl#1530) which compares Lux, Flux, SimpleChains,
Reactant, JAX, and PyTorch on common neural network workloads.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add "Neural Network Framework Benchmarks" section to the ordered docs
navigation, placed after the PINN benchmarks.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ChrisRackauckas ChrisRackauckas merged commit cdea967 into SciML:master Mar 27, 2026
2 of 3 checks passed
3 participants