Description & Motivation
Hi,
a while back I did some work with Jax, but I didn't want to miss the nice logging and boilerplate-reduction features of Lightning, which is built around PyTorch.
After some thinking and tinkering I settled on a solution that requires three steps (a minimal sketch follows the list):
- Set `self.automatic_optimization=False` so that we can do our own Jax-based optimization
- Load the data as numpy arrays to make it accessible to Jax
- Execute the forward pass, backward pass and gradient update step manually inside `training_step`
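Below is a minimal sketch of what these three steps can look like in practice. It uses optax for the update rule and a toy linear model; the names (`JaxRegressor`, `mse_loss`, the parameter pytree layout) are illustrative, not taken from the repo:

```python
import jax
import jax.numpy as jnp
import optax
import pytorch_lightning as pl


def mse_loss(params, x, y):
    # Toy linear model; `params` is a plain JAX pytree (a dict of arrays).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


class JaxRegressor(pl.LightningModule):
    def __init__(self, params, lr=1e-3):
        super().__init__()
        # Step 1: take the optimization loop out of Lightning's (torch) hands.
        self.automatic_optimization = False
        self.params = params
        self.optim = optax.adam(lr)
        self.opt_state = self.optim.init(params)
        # Compile the combined forward + backward pass once.
        self.grad_fn = jax.jit(jax.value_and_grad(mse_loss))

    def training_step(self, batch, batch_idx):
        # Step 2: the dataloader yields numpy arrays, which JAX consumes directly.
        x, y = batch
        # Step 3: forward, backward and parameter update, all on the JAX side.
        loss, grads = self.grad_fn(self.params, x, y)
        updates, self.opt_state = self.optim.update(grads, self.opt_state, self.params)
        self.params = optax.apply_updates(self.params, updates)
        # Lightning's logging machinery still works as usual.
        self.log("train_loss", float(loss))

    def configure_optimizers(self):
        # No torch optimizer to return; optax handles the updates above.
        return None
```

On the data side, "numpy mode" can be as simple as a regular `torch.utils.data.DataLoader` with a numpy collate, e.g. `collate_fn=lambda samples: tuple(np.stack(field) for field in zip(*samples))` for a dataset yielding `(x, y)` numpy pairs; training then runs through a normal `Trainer().fit(...)` call. This is just one way to wire it up; the repo linked below has complete working examples.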
A repo with examples can be found here: https://github.com/ludwigwinkler/JaxLightning (the name is somewhat unoriginal).
With those three changes I was able to train Jax models (the Equinox flavor) within Lightning, keeping the logging, code structure and general data and model flow of a LightningModule.
Pitch
@lantiga was kind enough to approach me about integrating this idea into Lightning itself, and I'd be more than happy to contribute the idea and implementation.
Since the way I got Jax to run within Lightning is quite lightweight, I'm hopeful that an integration could be feasible.
Fortunately, from my (rudimentary, non-large-scale) experience with Jax, it takes care of a lot of device placement and parallelization under the hood.
One thing to keep in mind is that the Jax ecosystem is less coalesced: Jax itself only provides the core toolset, and several neural network and optimization frameworks are built on top of it. This is in contrast to PyTorch, which comes with batteries included via torch.optim. In that sense, Lightning could either support a predefined list of packages, or leave a bit more of the basic implementation to the user and provide higher-level functionality such as gradient accumulation (see the sketch below).
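For context, here is what gradient accumulation already looks like on the user side with optax; whether Lightning should wrap such utilities or leave them to the respective Jax library is exactly the kind of design question to settle (the hyperparameters and pytree layout here are illustrative):

```python
import jax.numpy as jnp
import optax

# optax.MultiSteps wraps any optimizer so that gradients are accumulated over
# k consecutive `update` calls and the inner update is applied every k-th step.
optim = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)  # 4-batch accumulation

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros(1)}  # illustrative parameter pytree
opt_state = optim.init(params)
# From here, optim.update(grads, opt_state, params) and optax.apply_updates(...)
# are used exactly as in the training_step sketch above.
```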
An (incomplete) list of questions that somebody with more knowledge of Lightning could perhaps answer:
- Could Jax be detected automatically, or would there be an extra flag like `framework=[jax|pytorch]`?
- Would PyTorch and Jax be allowed to exist side by side in the same module? That could become messy.
- What is the depth of support for individual jax packages (Equinox vs Flax vs Haiku)?
- Checkpointing
- (probably so much more that I haven't thought of)
Happy to discuss the motivation, interest, feasibility and possible implementation ideas.
Alternatives
No response
Additional context
No response