The order of devices in the output of
Hi - thanks for the question! In auto-sharding contexts, JAX does its best to lay out data in a manner that is most efficient given the topology of the device(s). In the case of TPU, the topology may be a 3D mesh or a torus, depending on the system (look for Interconnect Topology in the Cloud TPU Docs). You can get a representation of the topology you are working with by printing the coordinates of the TPU devices:

```python
import jax
print([d.coords for d in jax.devices()])
```

Does that answer your question?
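On a larger slice the flat list of coordinates can be hard to read, so one option is to group device ids by coordinate. The helper below is a hypothetical sketch, not a JAX API: it assumes each device object exposes `id` and, on TPU, `coords` (as used in the snippet above). Devices without `coords` (for example CPU devices) are collected under a `None` key. Where JAX exposes more than one device per chip, several ids may share one coordinate.

```python
from collections import defaultdict


def devices_by_coords(devices):
    """Hypothetical helper: map each physical coordinate to the device ids there.

    Assumes TPU-style devices with `coords` (e.g. (x, y, z)) and `id`.
    Devices lacking `coords` are grouped under the key None.
    """
    groups = defaultdict(list)
    for d in devices:
        coords = getattr(d, "coords", None)
        # Normalize to a tuple so the coordinate can be used as a dict key.
        key = tuple(coords) if coords is not None else None
        groups[key].append(d.id)
    # Lexicographic order over coordinates; devices without coords go last.
    ordered = sorted(groups, key=lambda k: (k is None, k or ()))
    return {k: groups[k] for k in ordered}
```

On a TPU host you could print `devices_by_coords(jax.devices())`. The keys come out in lexicographic (x, y, z) order, which is not necessarily the order in which `jax.devices()` lists the devices, so comparing the two can make the device ordering easier to follow.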