JIT with closure function #20091
-
Hi all, I have a function generator which is like this
That is, I have a … My question is: will it lead to retracing each time I input a new …?
Replies: 4 comments 3 replies
-
Do you have a typo in your code? It appears the function is trying to return a reference to itself.
-
The answer to whether this will trigger a recompilation each time it is called depends on exactly when and where you're using `jax.jit`. So with that in mind, if you do something like this, it will always re-trace and re-compile on the second call:

```python
jax.jit(func_generator(base_fn))(params)
jax.jit(func_generator(base_fn))(params2)  # re-traces always
```

If you do something like this, you avoid creating another instance of the closed-over function, so the second call will hit the cache:

```python
jit_f = jax.jit(func_generator(base_fn))
jit_f(params)
jit_f(params2)  # does not re-trace if params2 is compatible with params
```
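One way to check the difference between the two patterns is to count how often the Python body of the returned function actually runs, since Python side effects inside a jitted function execute only while JAX is tracing it. The sketch below is a minimal illustration: the bodies of `base_fn` and `func_generator`, and the `trace_log` list, are hypothetical stand-ins and not the code from the question.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: grows only while `inner` is being traced


def base_fn(x):
    # Hypothetical stand-in for the real base function.
    return jnp.sin(x)


def func_generator(fn):
    # Hypothetical stand-in: returns a NEW function object on every call.
    def inner(params):
        trace_log.append(1)  # Python side effect, runs at trace time only
        return fn(params) * 2.0
    return inner


params = jnp.ones(3)
params2 = 2.0 * params  # same shape and dtype as params

# Pattern 1: a fresh closure is created and jitted for each call.
jax.jit(func_generator(base_fn))(params)
jax.jit(func_generator(base_fn))(params2)
count_new_closures = len(trace_log)

# Pattern 2: the jitted function is created once and reused.
trace_log.clear()
jit_f = jax.jit(func_generator(base_fn))
jit_f(params)
jit_f(params2)
count_reused = len(trace_log)

print(count_new_closures, count_reused)
```

For larger programs, setting `jax.config.update("jax_log_compiles", True)` before running should make JAX log each compilation, which can serve the same purpose without editing the function body.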
-
Thanks for the quick reply! I am confused about one statement you wrote: "As written, …". My actual function is a bit more complex than this one. Is there a tool I can use to check whether re-tracing happens?
-
Thanks, got it! So `inner` is already cached, right? Even if it generates another function object `comp_fn` inside `inner`, that would not lead to re-compilation, because everything inside `jit_f` is cached. A small follow-up question: I use `jit` outside a big for-loop, and in each iteration I call `func_generator` to generate inner functions and compute something by passing parameters to `jit_f`. The `base_fn` is wrapped outside the for-loop using a closure function like the one shown above. I am not sure whether this design will lead to problems. Do you have a suggestion for checking whether retracing or memory problems occur, given that I generate new function objects inside each time?
The inner function is the return value of `func_generator`, so when you JIT-compile that return value you are JIT-compiling (and caching) `inner`. Does that make sense?

If you have everything under an outer JIT, then the relevant JIT cache / compilation is for that outer-jit function. Each inner function call will be inlined into the jaxpr of that outer function, and that will be compiled once. There will not be any cache for functions created within a JIT-compiled function.
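A minimal sketch of the outer-JIT case described above, assuming a hypothetical `func_generator` that takes an extra `scale` argument, and a hypothetical `outer_traces` list used only to count traces. New function objects are created on every loop iteration, but that happens while `outer` is being traced, so only the outer function's cache matters.

```python
import jax
import jax.numpy as jnp

outer_traces = []  # hypothetical helper: grows only while `outer` is traced


def base_fn(x):
    # Hypothetical stand-in for the real base function.
    return jnp.tanh(x)


def func_generator(fn, scale):
    # Hypothetical stand-in: builds a new closure on every call.
    def inner(params):
        return fn(params) * scale
    return inner


@jax.jit
def outer(params):
    outer_traces.append(1)  # Python side effect, runs at trace time only
    total = jnp.zeros_like(params)
    for i in range(3):  # Python loop, unrolled during tracing
        inner = func_generator(base_fn, float(i + 1))  # new object per iteration
        total = total + inner(params)  # inlined into outer's jaxpr
    return total


x = jnp.ones(4)
out1 = outer(x)
out2 = outer(2.0 * x)  # same shape and dtype, so the cached compilation is reused

print(len(outer_traces))
```

The loop here has only three iterations on purpose; the point is that the closures created inside it do not add cache entries of their own.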
In general there is no issue with this approach, but keep in mind that if you wrap a `for` loop in `jit`, that…