
It's an internal API, so it may break without much warning, but you can do this:

from jax.interpreters import mlir

# flip_p is your Primitive and flip_impl its Python implementation (both from the question)
mlir.register_lowering(flip_p, mlir.lower_fun(flip_impl, multiple_results=flip_p.multiple_results))

The mlir.lower_fun utility takes a Python callable and returns an MLIR lowering rule. When JAX invokes that rule during lowering, it traces the callable to a jaxpr and lowers the resulting jaxpr.

(I didn't actually try running this, it's from memory, so I apologize if I got it slightly wrong!)
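The following is a minimal end-to-end sketch of that pattern. It is not the question's actual code: `flip_p` here is a hypothetical primitive that reverses an array, and `flip_impl` and `flip` are illustrative helpers. Because these are internal APIs, the import path for `Primitive` varies between JAX versions, so the sketch tries `jax.extend.core` first and falls back to `jax.core`.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir

try:
    from jax.extend.core import Primitive  # newer JAX versions
except ImportError:
    from jax.core import Primitive  # older JAX versions

# Hypothetical primitive, standing in for the one from the question.
flip_p = Primitive("flip")

def flip_impl(x):
    # Implementation written with ordinary jax.numpy ops: reverse every axis.
    return jnp.flip(x)

def flip_abstract_eval(x):
    # Reversing keeps shape and dtype, so the output aval equals the input aval.
    return x

def flip(x):
    return flip_p.bind(x)

flip_p.def_impl(flip_impl)  # used for eager (non-jit) evaluation
flip_p.def_abstract_eval(flip_abstract_eval)  # used while tracing under jit

# lower_fun traces flip_impl to a jaxpr and lowers that jaxpr whenever
# flip_p has to be lowered to MLIR (e.g. inside jax.jit).
mlir.register_lowering(
    flip_p, mlir.lower_fun(flip_impl, multiple_results=flip_p.multiple_results)
)

x = jnp.arange(6.0).reshape(2, 3)
print(jax.jit(flip)(x))
# Inspect the lowered module; the reverse op comes from tracing flip_impl.
print(jax.jit(flip).lower(x).as_text())
```

Note that this only supplies the lowering (plus the impl and abstract-eval rules needed to use it). Differentiating through `flip_p`, for example, would still require separately registered JVP/transpose rules.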

Replies: 1 comment, 2 replies (from @mattjj and @femtomc)

Answer selected by femtomc
Category: Q&A
2 participants