Custom JAX to MLIR translation #19364
-
It's certainly possible. In some sense we already have two different MLIR lowerings of JAX right now:
You could write a new lowering, using one of those as a starting point. Note that these are internal APIs, so I can't promise stability, but for research use or a prototype it should be fine.
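Before writing a new lowering, it can help to see which primitives a given function actually produces, since each one would need a lowering rule. The sketch below uses only the public `jax.make_jaxpr` API; the function `f` is a made-up example and is not part of the discussion above.

```python
import jax
from jax import lax


def f(x):
    # Hypothetical example function: sin(x) * x, written with lax primitives.
    return lax.mul(lax.sin(x), x)


# Trace the function to a jaxpr, the intermediate representation that JAX
# lowers to MLIR.
closed_jaxpr = jax.make_jaxpr(f)(1.0)

# Each equation in the jaxpr applies one primitive; a custom lowering would
# need a rule for every primitive name collected here.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

A custom translation could walk `closed_jaxpr.jaxpr.eqns` in the same way and emit its own operations for each equation.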
-
Thank you very much. Do you think it is doable to extend https://github.com/google/jax/blob/main/jax/_src/interpreters/mlir.py and customise a subset of the primitives instead of writing a new lowering? Maybe replacing the MLIR definition associated with the primitives through
-
I am interested in translating JAX to MLIR so that it can be used as input to other tools such as IREE. In particular, I am interested in custom translation, where an operator is translated to custom StableHLO.
I am aware that it is possible to define new JAX primitives, but I wanted to know if there is a way to define a custom translation for existing JAX primitives.
Also, I would like to know if it is possible to generate operations of other MLIR dialects besides StableHLO.
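For reference, the default (non-custom) translation to MLIR is available through JAX's public ahead-of-time lowering API. The following is a minimal sketch, assuming a recent JAX release in which `Lowered.as_text()` emits StableHLO by default; the function `f` and the input shape are made up for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function to be lowered.
    return jnp.sin(x) * 2.0


x = jnp.ones((4,), dtype=jnp.float32)

# Stage the function out without running it, then get the MLIR module text.
lowered = jax.jit(f).lower(x)
stablehlo_text = lowered.as_text()
print(stablehlo_text)
```

The printed module could then be saved to a file and passed to a downstream tool. Whether a particular tool accepts it unchanged depends on that tool and its version.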