EDIT3: The numpyro dependency has been removed from the code so that the issue is reproducible with jax (and numpy) only.
I have a function `f` that uses jax and numpyro (https://github.com/pyro-ppl/numpyro). I have a problem where an execution with GPU as the default device affects subsequent executions with CPU as the default device. This only happens when JIT is enabled.
I cannot share the content of `f`, but it should be stateless. Here is the issue:
```python
# cpu
with jax.default_device(jax.devices("cpu")[0]):
    y1 = f()

# cpu again
with jax.default_device(jax.devices("cpu")[0]):
    y2 = f()

# gpu
with jax.default_device(jax.devices("gpu")[0]):
    y3 = f()

# cpu yet again
with jax.default_device(jax.devices("cpu")[0]):
    y4 = f()
```
Here, I expect `y1 == y2 == y4`, since they are executed on the same device. However, while `y1 == y2`, `y4` is different from both, implying that the on-GPU computation of `y3` affects the on-CPU computation of `y4`. Indeed, even if I comment out `y1` and `y2`, `y4` does not change. If I comment out `y3`, then `y1 == y2 == y4`. If I set `JAX_DISABLE_JIT=1`, the issue disappears (`y1 == y2 == y4`), which makes me suspect this is related to the JIT cache.
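For what it's worth, I understand that `y3` differing from the CPU results can be legitimate on its own: floating-point addition is not associative, so backends that reduce in different orders (as CPU and GPU typically do) can produce different values for the same expression. That alone would not explain `y4` changing on the same device, though. A minimal pure-Python illustration of the non-associativity:

```python
# Floating-point addition is not associative, so different summation
# orders (e.g. different CPU vs GPU reduction orders) can give
# different results for the same mathematical expression.
big = 1e16
left_to_right = (big + 1.0) - big  # the 1.0 is absorbed -> 0.0
regrouped = 1.0 + (big - big)      # -> 1.0
print(left_to_right, regrouped)
```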
Since I haven't shared the content of `f`, I guess it is difficult to tell what is going on just from the information above, but I'd appreciate any advice or hints on what is happening or how to debug it.
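One debugging approach I am considering (a generic sketch; `first_divergence` and the recorded traces are hypothetical stand-ins, not JAX APIs): record intermediate values inside `f` during the first CPU run, record them again during the post-GPU CPU run, and report the first checkpoint where the two traces disagree:

```python
# Generic divergence locator: compare two traces of intermediate
# values, recorded from the "before GPU" and "after GPU" CPU runs.
def first_divergence(trace_a, trace_b, rtol=1e-6):
    """Return the index of the first checkpoint where the traces
    disagree beyond a relative tolerance, or None if they match."""
    for i, (a, b) in enumerate(zip(trace_a, trace_b)):
        if abs(a - b) > rtol * max(abs(a), abs(b), 1.0):
            return i
    return None

# Fake traces standing in for the y1 and y4 runs:
run_before_gpu = [1.0, 2.5, 3.141, 7.0]
run_after_gpu = [1.0, 2.5, 3.142, 9.0]
print(first_divergence(run_before_gpu, run_after_gpu))  # -> 2
```

Once the first diverging checkpoint is known, the corresponding sub-expression of `f` can be extracted into a smaller reproduction.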
I'm using jax==0.3.21 and jaxlib==0.3.20+cuda11.cudnn8.
EDIT: Here is the code to reproduce the issue. I'm using numpyro==0.10.1.
EDIT2: I asked about the issue in numpyro's issues too: pyro-ppl/numpyro#1496.
output (without JAX_DISABLE_JIT=1):
output (with JAX_DISABLE_JIT=1):