In the tutorial,

```python
import jax

def unjitted_loop_body(prev_i):
    # Helper defined earlier in the tutorial; a simple increment is assumed here
    return prev_i + 1

def g_inner_jitted_poorly(x, n):
    i = 0
    while i < n:
        # Don't do this!
        i = jax.jit(unjitted_loop_body)(i)
    return x + i
```

This is quite helpful, but I am a little confused about how to jit the following. I want a function that outputs a function, so that I can pass in some parameters:

```python
def make_compute_fn(param1, param2):
    ...
    def compute(param1, param2):
        ...
    return compute
```

To get a jitted function, is it appropriate to do this?

```python
from jax import jit

def make_compute_fn(param1, param2):
    ...
    @jit
    def compute(param1, param2):
        ...
    return compute
```

Again, thanks for posting the JAX tutorials, I really appreciate them!
Your last code snippet, returning the jitted function, would be the recommended way to do this.

The example you quote from the tutorial is about jitting within a loop, which leads to repeated JIT-compilation overhead. If you're creating a jitted function once within a factory function, that concern about repeated compilation overhead does not apply.
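A minimal sketch of the factory pattern, assuming two hypothetical parameters `scale` and `offset` and a toy body `scale * x + offset` (neither is from the question). Here the parameters are captured by the closure rather than passed again as arguments to `compute`, so `jit` treats them as constants fixed at trace time. The `trace_count` counter is hypothetical instrumentation: its Python increment only runs while JAX traces the function, so it shows that repeated calls with the same input shape and dtype reuse the compiled version.

```python
import jax
import jax.numpy as jnp

# Hypothetical instrumentation: counts how many times JAX traces compute.
trace_count = 0

def make_compute_fn(scale, offset):
    # scale and offset are captured by the closure, so jit bakes them into
    # the traced computation as constants.
    @jax.jit
    def compute(x):
        global trace_count
        trace_count += 1  # Python side effect: runs only during tracing
        return scale * x + offset
    return compute

compute = make_compute_fn(2.0, 1.0)   # jax.jit is applied once, here
y1 = compute(jnp.arange(3.0))         # first call: trace + compile
y2 = compute(jnp.arange(3.0) + 1.0)   # same shape/dtype: cached version reused

print(y1, trace_count)  # [1. 3. 5.] 1
```

Each call to `make_compute_fn` returns a new function object, so expect each returned function to be traced and compiled on its own first call. Calling the factory once and reusing the result keeps this to a single compilation; calling the factory inside a loop would bring back the repeated-compilation overhead of the tutorial example.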