I am struggling to understand how JAX decides when arrays are committed to a particular sharding. I have a fairly simple question about this example: https://docs.jax.dev/en/latest/notebooks/host-offloading.html#data-placement-with-device-put. In general, my question is twofold: what is happening under the hood before a sharding is specified? And when a sharding is never specified, when does JAX commit to a device? I am trying to understand both cases because I'd like to minimize unnecessary memory transfers.
When you create an array (in eager mode) it will be materialized on the default device, but will not be "committed", which means that when you operate on it, the compiler is free to move it to whichever device or sharding is most efficient. You can see that here:

import jax
jax.config.update("jax_num_cpu_devices", 2)
x = jax.numpy.arange(10)
print(f"{x.device=}\n{x.committed=}\n")
When you later call device_put on that array, the buffer will be materialized on the specified device, and will also be considered "committed" to that device, meaning that the compiler will respect that data placement for any operations performed on the array:

x2 = jax.device_put(x, jax.devices()[1])
print(f"{x2.device=}\n{x2.committed=}\n")
Note that the above is true outside JIT (i.e. when executing eagerly). Within JIT, when you create an array using something like jnp.arange, you get an abstract tracer rather than a concrete buffer, and the compiler chooses its placement when the function is compiled, based on the sharding of the inputs and outputs.
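Here is a minimal sketch of the JIT case, assuming the same two-CPU-device setup as above (the function f and the array names are just illustrative): an array created inside a jitted function is a tracer, and a committed input pins where the compiled computation runs and where its result lives:

import jax
import jax.numpy as jnp

# must be set before any JAX computation initializes the backend
jax.config.update("jax_num_cpu_devices", 2)

@jax.jit
def f(a):
    b = jnp.arange(10)  # created inside the trace; the compiler chooses its placement
    return a + b

x = jax.device_put(jnp.arange(10), jax.devices()[1])  # committed to device 1
y = f(x)
print(f"{y.device=}\n{y.committed=}\n")  # runs on device 1; the result is committed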
Does that help answer your question?