Hard to say without more information, but one thing to keep in mind is that JAX compilation unrolls Python control flow. So something like

```python
from jax import jit

@jit
def f(x):
    for i in range(1000000):
        x += 1
    return x
```

would result in 1,000,000 …
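To make the unrolling point concrete, here is a minimal sketch contrasting a Python `for` loop with `jax.lax.fori_loop`. The names `f_unrolled`, `f_rolled` and `N` are illustrative, and `N` is kept small so tracing the unrolled version stays cheap. Counting the equations in the traced program with `jax.make_jaxpr` should show one `add` per iteration for the Python loop, while the `fori_loop` body is traced once and becomes a single loop primitive (with static bounds it is lowered to `scan`, otherwise to `while`).

```python
import jax
from jax import lax

N = 1000  # illustrative; kept small so tracing the unrolled loop is cheap

def f_unrolled(x):
    # The Python loop runs at trace time: each iteration adds one `add`
    # operation to the program, so program size grows with N.
    for _ in range(N):
        x += 1
    return x

def f_rolled(x):
    # fori_loop traces the body once and emits a single loop primitive,
    # so program size does not depend on N.
    return lax.fori_loop(0, N, lambda i, x: x + 1, x)

f_unrolled_jit = jax.jit(f_unrolled)
f_rolled_jit = jax.jit(f_rolled)

# Number of top-level equations in each traced program (jaxpr)
print(len(jax.make_jaxpr(f_unrolled)(0).jaxpr.eqns))
print(len(jax.make_jaxpr(f_rolled)(0).jaxpr.eqns))
```

Both functions compute the same value; only the size of what gets handed to the compiler differs, which is usually what drives compile time and compile-time memory for long loops.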
Hey,
What are some best practices for writing JAX code that doesn't OOM during compilation?
If I understand correctly, it involves using scans, but I haven't found any specific information about this.
Thanks!
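On the scan part of the question: a minimal sketch of `jax.lax.scan`, assuming the loop carries state from one iteration to the next. `step`, `running_sum_impl` and `running_sum` are hypothetical names. The body is traced once no matter how long the input is, so this does the same 1,000,000 additions as the loop in the reply without a million-operation program. Note that the stacked per-step outputs (`partials` here) still take memory at runtime proportional to the length; returning `None` as the per-step output avoids that if only the final carry is needed.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # carry: running total so far; x: the current element of the input
    new_carry = carry + x
    # Return (next carry, per-step output); scan stacks the per-step outputs
    return new_carry, new_carry

def running_sum_impl(xs):
    # `step` is traced once, regardless of len(xs)
    init = jnp.zeros((), dtype=xs.dtype)
    total, partials = lax.scan(step, init, xs)
    return total, partials

running_sum = jax.jit(running_sum_impl)

# Adding 1.0 a million times, as in the loop example above
xs = jnp.ones(1_000_000, dtype=jnp.float32)
total, partials = running_sum(xs)
```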