-
If you are passing a [...]

For example, I'd expect this to cause 10 recompilations of each lambda function:

    for i in range(10):
        a = lax.cond(x, lambda: a, lambda: a + 1)

Whereas this will produce the same result, but I would expect to see only a single compilation of each lambda function:

    same = lambda a: a
    augment = lambda a: a + 1
    for i in range(10):
        a = lax.cond(x, same, augment, a)
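One rough way to check this without reading compile logs is to count how often each branch gets traced. The sketch below is only an illustration: the `counts` dict and the `run_inline` / `run_hoisted` helpers are made-up names, nested `def`s stand in for the inline lambdas (they are also fresh function objects on every iteration), and the exact counts may differ between JAX versions.

```python
import jax.numpy as jnp
from jax import lax

# Python side effects only run while JAX traces a function,
# so these counters record how often the branches are traced.
counts = {"inline": 0, "hoisted": 0}

def run_inline(pred, a, steps=10):
    for _ in range(steps):
        # New function objects on every iteration, like inline lambdas.
        def same():
            counts["inline"] += 1
            return a
        def augment():
            counts["inline"] += 1
            return a + 1
        a = lax.cond(pred, same, augment)
    return a

# Branches defined once and reused; the operand is passed explicitly.
def same(a):
    counts["hoisted"] += 1
    return a

def augment(a):
    counts["hoisted"] += 1
    return a + 1

def run_hoisted(pred, a, steps=10):
    for _ in range(steps):
        a = lax.cond(pred, same, augment, a)
    return a

pred = jnp.array(False)        # False selects the second branch (augment)
start = jnp.float32(0.0)
out_inline = run_inline(pred, start)
out_hoisted = run_hoisted(pred, start)
print(float(out_inline), float(out_hoisted))
print(counts)
```

Both loops should compute the same value; the interesting part is how the two counters compare, with the hoisted version expected to be traced far less often.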
-
I had tried more or less exactly what you suggested in a jit-compiled function, and then only a single compile happens.
or comment out the [...]
Conclusion: Do not use inline functions with [...]
Thanks!
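For reference, a minimal sketch of the jit-compiled case, under the assumption that the loop body looks like the example above. The names `run` and `outer_traces` are made up for illustration. The Python loop is unrolled while `run` is traced, so the inline lambdas are only evaluated during that one trace and later calls with the same input types reuse the cached executable.

```python
import jax
import jax.numpy as jnp
from jax import lax

outer_traces = 0

@jax.jit
def run(pred, a):
    global outer_traces
    outer_traces += 1              # runs only while `run` is being traced
    for _ in range(10):            # unrolled at trace time
        a = lax.cond(pred, lambda: a, lambda: a + 1)
    return a

pred = jnp.array(False)            # False selects the second branch
a0 = jnp.float32(0.0)
first = run(pred, a0)
second = run(pred, a0)             # same shapes and dtypes as the first call
print(float(first), float(second), outer_traces)
```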
-
But, defining the [...]
This re-compiles [...]
-
After activating compile logging with JAX_LOG_COMPILES=1, I see hundreds of occurrences of what looks like JAX tracing and re-compiling the prim_fun of a lax.cond while running a longer JAX-based computation.

There are only very few lax.cond() calls in that code, and I was unable to reproduce this behavior in a simple example program.

I first thought that a lax.cond() calling an inline lambda function in a for loop of a jit-compiled function could be causing this, but that does not seem to trigger it.

Does anyone know what could be causing this? Since every compilation takes on the order of 0.5 s, this is quite a time penalty...