Description
Continuing the discussion: https://discourse.julialang.org/t/node-training-performance-w-lux-vs-torchdiffeq/134860/5
Essentially, it would be good if existing PyTorch models, layers, etc. could be converted to StableHLO and used within Reactant. This would lower the barrier to entry for deep learning in Julia for people who are heavily invested in the PyTorch ecosystem.
It looks like the easiest way to do this would be using TorchXLA and one of the two methods defined here: https://docs.pytorch.org/xla/master/features/stablehlo.html
Essentially, we would just need the glue between ReactantPythonCallExt and one of these methods. I would be partial to the exported model approach, since it likely has a smaller surface for things to go wrong and does not depend on JAX.
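For reference, here is a rough sketch of what the PyTorch side of the exported model approach might look like, following the example in the linked PyTorch/XLA docs. The helper name `export_to_stablehlo_text` is made up for illustration, and it assumes `torch` and `torch_xla` are installed; how the resulting text would be handed to Reactant is exactly the missing glue and is not shown.

```python
def export_to_stablehlo_text(model, sample_args):
    """Export a torch.nn.Module to StableHLO text (hypothetical helper).

    `sample_args` is a tuple of example inputs used to trace the model.
    Imports live inside the function so that merely defining it does not
    require torch or torch_xla to be installed.
    """
    from torch.export import export
    from torch_xla.stablehlo import exported_program_to_stablehlo

    # Trace the model into an ExportedProgram using the sample inputs.
    exported = export(model, sample_args)
    # Lower the exported program to StableHLO.
    stablehlo_program = exported_program_to_stablehlo(exported)
    # 'forward' is the method name used in the PyTorch/XLA docs example.
    return stablehlo_program.get_stablehlo_text("forward")
```

Calling it would look something like `export_to_stablehlo_text(model, (torch.randn(4, 3, 224, 224),))` for an image model, with the input shape chosen to match whatever the model expects.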
@wsmoses am I understanding this correctly?
As an aside, it would be extremely useful if we could find a way to automatically bring over the weights, whether from pretrained backbones or from whatever random initialization was used on the PyTorch side.
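One possible starting point for the weights, as a sketch only: a PyTorch `state_dict` is a flat mapping with dot-separated names, so it could be regrouped into a nested tree of NumPy arrays before crossing over to Julia. Both function names below are hypothetical, and whether the nested layout lines up with what the Julia side expects (for example Lux-style nested parameters) is an assumption that would need checking.

```python
from typing import Any, Dict, Mapping


def nest_state_dict(flat: Mapping[str, Any]) -> Dict[str, Any]:
    """Turn {"a.b.weight": x} into {"a": {"b": {"weight": x}}}."""
    tree: Dict[str, Any] = {}
    for dotted_name, value in flat.items():
        *parents, leaf = dotted_name.split(".")
        node = tree
        # Walk (and create as needed) one nested dict per name component.
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return tree


def state_dict_to_numpy_tree(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
    """Convert each tensor to a NumPy array on the CPU, then nest by name.

    Assumes the values are torch tensors (anything exposing
    .detach().cpu().numpy()), as returned by model.state_dict().
    """
    arrays = {name: t.detach().cpu().numpy() for name, t in state_dict.items()}
    return nest_state_dict(arrays)
```

Usage would be along the lines of `state_dict_to_numpy_tree(model.state_dict())`. Note that this does nothing about layout differences (row-major vs column-major, or transposed weight conventions), which would still have to be handled per layer type.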