PyTorch StableHLO Support #2065

@csvance

Continuing the discussion: https://discourse.julialang.org/t/node-training-performance-w-lux-vs-torchdiffeq/134860/5

Essentially, it would be good if existing PyTorch models, layers, etc. could be converted to StableHLO and used within Reactant. This would lower the barrier to entry for deep learning and neural networks in Julia for people who are heavily invested in the PyTorch ecosystem.

It looks like the easiest way to do this would be to use TorchXLA with one of the two methods described here: https://docs.pytorch.org/xla/master/features/stablehlo.html

Essentially, we would just need the glue between ReactantPythonCallExt and one of these methods. I would be partial to the exported model approach, since it is likely a smaller surface for things to go wrong and it does not depend on JAX.

@wsmoses am I understanding this correctly?

As an aside, it would be extremely useful if we could find a way to automatically bring over the weights, whether they come from a pretrained backbone or from whatever random initialization was used on the PyTorch side.
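On the PyTorch side, both pretrained and freshly initialized weights live in the module's `state_dict`, so a plausible first step is to dump it as NumPy arrays keyed by parameter name. This is a minimal sketch under that assumption; `state_dict_to_numpy` is a hypothetical helper, and mapping the names onto a Julia model's parameters is not addressed here.

```python
# Sketch: pull a PyTorch model's parameters and buffers (pretrained or freshly
# initialized) out as NumPy arrays keyed by their state_dict names.
import numpy as np
import torch


def state_dict_to_numpy(model: torch.nn.Module) -> dict[str, np.ndarray]:
    """Hypothetical helper: copy every parameter and buffer in `model` to NumPy."""
    return {
        name: tensor.detach().cpu().numpy().copy()
        for name, tensor in model.state_dict().items()
    }


if __name__ == "__main__":
    torch.manual_seed(0)  # makes PyTorch's random init reproducible for this demo
    layer = torch.nn.Linear(4, 3)
    params = state_dict_to_numpy(layer)
    for name, arr in params.items():
        # nn.Linear stores weight as (out_features, in_features), row-major
        print(name, arr.shape, arr.dtype)
```

These arrays are row-major (C order), whereas Julia arrays are column-major, so whatever mechanism hands them to Julia has to handle the layout explicitly; reinterpreting the raw buffer without a copy would make an `(out, in)` matrix appear transposed.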
