This is a general question. I found that my JAX-jitted code is unexpectedly slow and its memory usage keeps increasing. I suspect that part of the jitted code is being traced multiple times. Are there any convenient tools for checking retracing and for profiling? I know I can put a `print` in the jitted function and count how many times it prints, but are there better approaches? I found that in TF, users can use
With respect to figuring out *why* a function is being recompiled: given some problematic function

```python
@jax.jit
def big_function_mysteriously_being_recompiled(arg1, arg2, ...):
    ...

big_function_mysteriously_being_recompiled(arg1, arg2, ...)
```

I define a helper function:

```python
import functools
import jax

@functools.partial(jax.jit, static_argnums=1)
def why(arg, name):
    # `print` runs only while JAX traces `why`, so `name` is printed each time
    # `arg` arrives with a shape/dtype (or pytree structure) not seen before.
    print(name)
```

and change the call site to:

```python
why(arg1, "arg1")
why(arg2, "arg2")
...
big_function_mysteriously_being_recompiled(arg1, arg2, ...)
```

This makes it clear which argument is changing.
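Here is a minimal, self-contained sketch of the same trick. `big_function` is a hypothetical stand-in for the problematic function. The `retraced` list is also hypothetical; it only records what `print` shows so the result can be checked. Only `arg2` changes shape on the last iteration, so only `"arg2"` should be reported a second time.

```python
import functools

import jax
import jax.numpy as jnp

retraced = []  # hypothetical: records each name printed during tracing


@functools.partial(jax.jit, static_argnums=1)
def why(arg, name):
    # Executes only when `why` is (re)traced for a new (abstract arg, name) pair.
    print(name)
    retraced.append(name)


@jax.jit
def big_function(arg1, arg2):  # hypothetical stand-in for the slow function
    return arg1 @ arg2


for n in (3, 3, 4):
    arg1 = jnp.ones((2, 2))  # same shape and dtype on every iteration
    arg2 = jnp.ones((2, n))  # shape changes on the last iteration
    why(arg1, "arg1")
    why(arg2, "arg2")
    big_function(arg1, arg2)
```

Because the cache key of `why` includes the static `name` and the abstract value of `arg`, a repeated shape hits the cache and prints nothing.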
Thanks for the question!

You can use the `jax_log_compiles` option. You can set it in any of these ways:

- the context manager `with jax.log_compiles(True): ...`
- `jax.config.update('jax_log_compiles', True)`
- the environment variable `JAX_LOG_COMPILES=1` (or any truthy value)
- if you use `absl`, the `absl` command-line flag `--jax_log_compiles=1`
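As a minimal sketch of what this catches, the example below enables the option through `jax.config.update`. `scale` and the `trace_count` counter are hypothetical and exist only to make the retrace visible and checkable. With logging on, JAX should log a compilation of `scale` on the first call and again on the third call, which uses a new input shape. The second call reuses the cached compilation.

```python
import jax
import jax.numpy as jnp

# Same effect as running with JAX_LOG_COMPILES=1: JAX logs each compilation.
jax.config.update("jax_log_compiles", True)

trace_count = 0  # hypothetical counter, for illustration only


@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effects run only while JAX traces `scale`
    return 2.0 * x


scale(jnp.ones(3))  # first call: `scale` is traced and compiled
scale(jnp.ones(3))  # same shape and dtype: cache hit, no retrace
scale(jnp.ones(4))  # new shape: `scale` is retraced and recompiled
print(trace_count)
```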
We'd like to improve the tooling around this further, in particular to provide some explanation as to why a function is getting recompiled. At the moment the information is pretty minimal.