-
|
Hi, I am trying to understand where JIT-compiled operations will run. I am familiar with the following pattern:
# some large array on CPU
A = jnp.zeros((50000000, 50000000), device=jax.devices("cpu")[0])
# smaller array on default device (GPU id 0)
b = jnp.zeros(5)
Now, if I have a JIT-compiled function which takes both A and b as inputs, where will the computation run? I am specifically forming the large array on the CPU due to memory constraints. The answer will help me prevent any attempt to transfer that large array to the GPU and hit an OOM. Thanks! |
Replies: 2 comments 2 replies
-
|
In this case, since you explicitly specified the device of A and did not explicitly specify the device of b, array A is committed to its device and b is uncommitted. So if you do a simple operation like A + b (assuming the shapes are compatible), the computation will happen on the device where A resides. For example:
>>> import jax
>>> A = jax.numpy.ones((4, 4), device=jax.devices('cpu')[0])
>>> print(A.device, A.committed)
TFRT_CPU_0 True
>>> b = jax.numpy.arange(4.0)
>>> print(b.device, b.committed)
cuda:0 False
>>> c = A + b
>>> print(c.device, c.committed)
TFRT_CPU_0 True
For a general jit-compiled function, that question is not possible to answer without more information, because the function definition may include sharding logic that moves the data to one device or the other (for example, using one or more of the mechanisms discussed at https://docs.jax.dev/en/latest/sharded-computation.html). |
-
|
@jakevdp Thank you very much for the quick reply! Currently, I don't have sharding, and the only committed array is the large one, so I guess things should work out fine. |