Hi all! I've been writing new primitives, defining them according to the documentation tutorial. Here's an example, a probabilistic flip:

```python
# Assumed imports; the original post did not show them.
import jax
import jax.core as jc
from jax import abstract_arrays

flip_p = jc.Primitive("flip")

def flip(key, p):
    return flip_p.bind(key, p)

def flip_abstract_call(key, p):
    # Output avals: the new key and one boolean per entry of p.
    return key, abstract_arrays.ShapedArray(p.shape, bool)

def flip_impl(key, p):
    key, sub_key = jax.random.split(key)
    # Sample with the fresh subkey, not the key that is handed back to the caller.
    return key, jax.random.bernoulli(sub_key, p)

flip_p.def_abstract_eval(flip_abstract_call)
flip_p.def_impl(flip_impl)
flip_p.multiple_results = True
```

Now, I really want to define the XLA lowering semantics of this primitive, but I'd just like the lowering to follow the implementation. Is there an easy way to do this?
Replies: 1 comment 2 replies
It's an internal API, and thus subject to breakage without much warning, but you can do:

```python
from jax.interpreters import mlir

mlir.register_lowering(flip_p, mlir.lower_fun(flip_impl, multiple_results=flip_p.multiple_results))
```

The `mlir.lower_fun` utility takes a Python callable and generates an MLIR/XLA lowering rule which, when called, traces that callable and lowers the resulting jaxpr.

(I didn't actually try running this, it's from memory, so I apologize if I got it slightly wrong!)
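For anyone who wants to try the suggestion end to end, here is a minimal sketch that puts the question's primitive and the `lower_fun` registration together and calls it under `jax.jit`. It rests on a few assumptions that are not from the thread: that `Primitive` is importable from `jax.extend.core` (with a fallback to `jax.core` for older releases), that `jax.core.ShapedArray` stands in for `abstract_arrays.ShapedArray`, and that `jax.interpreters.mlir` still exposes `register_lowering` and `lower_fun` with these signatures. Since this is an internal API, it may not hold for every JAX version.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir  # internal API; may change between releases

try:  # assumption: newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # fallback for older releases
    from jax.core import Primitive

flip_p = Primitive("flip")
flip_p.multiple_results = True

def flip(key, p):
    return flip_p.bind(key, p)

def flip_abstract_eval(key, p):
    # Two outputs: the new key (same abstract value as the input key)
    # and one boolean per entry of p.
    return key, jax.core.ShapedArray(p.shape, jnp.bool_)

def flip_impl(key, p):
    key, sub_key = jax.random.split(key)
    return key, jax.random.bernoulli(sub_key, p)

flip_p.def_abstract_eval(flip_abstract_eval)
flip_p.def_impl(flip_impl)

# Reuse the Python implementation as the lowering rule: lower_fun traces
# flip_impl and lowers the resulting jaxpr.
mlir.register_lowering(flip_p, mlir.lower_fun(flip_impl, multiple_results=True))

key = jax.random.PRNGKey(0)
# With p = 1 every flip must come up True, which makes the result easy to check.
new_key, always = jax.jit(flip)(key, jnp.ones((3,)))
```

Without the `register_lowering` line, the eager call `flip(key, p)` still works through `def_impl`, but tracing under `jax.jit` has no rule for turning the primitive into MLIR, which is the gap the answer above fills.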