Changing lowering rule for a primitive #19897
-
I want to export some statistical computation to TFLite and ran into a problem with a missing operation (`erf_inv`). This is what I tried:

```python
import jax
import jax.numpy as jnp
from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo

if __name__ == '__main__':
    k = jax.random.key(4)

    old_impl = jax.lax.erf_inv_p.impl

    def new_impl(x):
        print(x)
        return 0.5

    jax.lax.erf_inv_p.def_impl(new_impl)

    def lowering(ctx, xx):
        return [hlo.ConstantOp(...)]

    mlir.register_lowering(jax.lax.erf_inv_p, lowering, platform='cpu')

    @jax.jit
    def new_sample(k):
        with jax.disable_jit():
            x = jax.random.normal(k)
        return x

    x = new_sample(k)
    print(x)
```
-
Unless you want to directly call some `hlo` method, the easiest way to bootstrap a lowering rule is by applying `mlir.lower_fun` to the `impl` rule:

```python
lowering = mlir.lower_fun(new_impl, multiple_results=False)
```

If you're looking for the syntax to use `ConstantOp`, it would be something like this:

```python
import numpy as np

def lowering(ctx, xx):
    # Emit a scalar f32 constant 0.5; the operand xx is ignored.
    return [hlo.ConstantOp(
        mlir.ir.DenseElementsAttr.get(
            np.array(0.5, dtype=np.float32),
            type=mlir.ir.F32Type.get())
    ).result]
```
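A fuller sketch of the `lower_fun` route follows. It assumes the goal is to replace `erf_inv` with more basic ops (log, sqrt, multiply/add, select) that an exporter is more likely to support. The names `erf_inv_approx` and `_horner` are hypothetical. The polynomial is Mike Giles' single-precision erfinv approximation, swapped in here in place of the constant `0.5`, and it assumes float32 inputs. In the question, `jax.random.normal` is presumably what pulls in `erf_inv_p`, since it maps uniform samples through `erf_inv`. Whether the TFLite converter accepts the resulting ops has not been checked here.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.interpreters import mlir

# Hypothetical replacement for erf_inv: Giles' single-precision
# approximation, built only from log, sqrt, mul/add and select.
# Assumes float32 inputs (coefficients are float32-accurate).
_CENTRAL = [2.81022636e-08, 3.43273939e-07, -3.5233877e-06,
            -4.39150654e-06, 0.00021858087, -0.00125372503,
            -0.00417768164, 0.246640727, 1.50140941]
_TAIL = [-0.000200214257, 0.000100950558, 0.00134934322,
         -0.00367342844, 0.00573950773, -0.0076224613,
         0.00943887047, 1.00167406, 2.83297682]


def _horner(coeffs, w):
    # Evaluate the polynomial with Horner's rule, highest degree first.
    p = jnp.full_like(w, coeffs[0])
    for c in coeffs[1:]:
        p = c + p * w
    return p


def erf_inv_approx(x):
    w = -jnp.log((1.0 - x) * (1.0 + x))
    p_central = _horner(_CENTRAL, w - 2.5)      # used when w < 5
    p_tail = _horner(_TAIL, jnp.sqrt(w) - 3.0)  # used near x = +/-1
    return jnp.where(w < 5.0, p_central, p_tail) * x


# lower_fun traces erf_inv_approx and emits the lowerings of the
# primitives it uses, so the output matches the operand's shape and dtype.
# This CPU-only rule replaces the default erf_inv lowering for every
# later compilation in this process.
mlir.register_lowering(
    lax.erf_inv_p,
    mlir.lower_fun(erf_inv_approx, multiple_results=False),
    platform="cpu",
)

if __name__ == "__main__":
    x = jnp.array([-0.9, 0.0, 0.5, 0.999], dtype=jnp.float32)
    print(jax.jit(lax.erf_inv)(x))
```

Unlike the scalar `ConstantOp` rule, a `lower_fun`-based rule follows the operand's shape automatically. To check which rule was used, the text from `jax.jit(lax.erf_inv).lower(x).as_text()` should no longer contain `chlo.erf_inv` when compiling for CPU.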