Calling jitted function from within jitting and inlining #6681
I wonder what happens when JAX encounters a call to a jitted function inside a jitting context. Does JAX inline the original call into the current jitting context, or does it compile the two separately? I tried inspecting this with `jax.make_jaxpr`. For example:

```
>>> import jax
>>> from functools import partial
>>> @partial(jax.jit, static_argnums=1)
... def test(x, i):
...     y = x.real
...     if i == 0:
...         return y
...     else:
...         return test(y*y, i-1)
...
>>> jax.make_jaxpr(lambda x: test(x, 2))(3.0+1j)
{ lambda ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda ; a.
                                 let b = real a
                                     c = mul b b
                                     d = xla_call[ backend=None
                                                   call_jaxpr={ lambda ; a.
                                                                let b = mul a a
                                                                in (b,) }
                                                   device=None
                                                   donated_invars=(False,)
                                                   name=test ] c
                                 in (d,) }
                    device=None
                    donated_invars=(False,)
                    name=test ] a
  in (b,) }
```
Replies: 1 comment
Jitted function calls are not inlined by default. We recently added #6584, which makes inlining available.

In practice, though, this is more about the aesthetics of the printed jaxpr than about performance. My understanding is that XLA will produce the same compiled code whether or not the function call is inlined.
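Here is a minimal sketch that builds the recursive `test` from the question both ways, so you can compare them. It assumes a JAX version where `jax.jit` accepts the `inline` keyword, which is presumably the option #6584 introduced. `make_test` and `nested_call_count` are hypothetical helpers written only for illustration. The second one counts top-level jaxpr equations whose parameters hold a sub-jaxpr, which is how a call that was not inlined appears. In the output above that is `xla_call`; newer JAX releases may use a different primitive name.

```python
from functools import partial

import jax


def make_test(inline):
    """Hypothetical helper: rebuild the question's recursive `test`,
    forwarding `inline` to `jax.jit` (assumes the keyword exists)."""
    @partial(jax.jit, static_argnums=1, inline=inline)
    def test(x, i):
        y = x.real
        if i == 0:
            return y
        # Recursive call to the jitted `test` while it is being traced.
        return test(y * y, i - 1)
    return test


def nested_call_count(fn, x):
    """Hypothetical helper: count top-level equations in the jaxpr of
    `fn(x)` whose params carry a sub-jaxpr, i.e. calls left un-inlined."""
    closed = jax.make_jaxpr(fn)(x)
    return sum(
        any(hasattr(p, "eqns") for p in eqn.params.values())
        for eqn in closed.jaxpr.eqns
    )


nested = make_test(inline=False)   # default: inner calls stay nested
inlined = make_test(inline=True)   # splice calls into the caller's jaxpr

x = 3.0 + 1j
print(nested(x, 2), inlined(x, 2))
print(nested_call_count(lambda v: nested(v, 2), x))
print(nested_call_count(lambda v: inlined(v, 2), x))
```

Both variants compute the same value. Only the structure of the traced jaxpr differs, which matches the point that inlining here is mostly cosmetic. This sketch compares results only. It does not inspect the compiled XLA code.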