When you create an array (in eager mode) it will be materialized on the default device, but will not be "committed", which means that when you operate on it, the compiler is free to move it to whichever device or sharding is most efficient. You can see that here:

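```python
import jax
# Must be set before the CPU backend is initialized (i.e. before any array is created).
jax.config.update("jax_num_cpu_devices", 2)

x = jax.numpy.arange(10)
print(f"{x.device=}\n{x.committed=}\n")
# Output:
# x.device=CpuDevice(id=0)
# x.committed=False
```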

When you later call device_put on that array, the buffer will be materialized on the specified device, and will also be considered "committed" to that device, meaning that the compiler will respect that data placement for any operations performed on the array:

```python
x2 = jax.
```