
Hi, thanks for the question! When sharding is handled automatically, JAX tries to lay out data in whatever way is most efficient for the topology of the device(s) it is running on. On TPUs, that topology may be a 3D mesh or a torus, depending on the system (see "Interconnect topology" in the Cloud TPU docs).
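To see a topology-aware layout in practice, here is a minimal sketch. It assumes a reasonably recent JAX with the `jax.sharding` API. It builds a mesh with `mesh_utils.create_device_mesh`, which (as I understand it) orders TPU devices so that neighbouring mesh positions follow the physical interconnect. It then shards an array across that mesh and lets `jax.jit` propagate the sharding through a computation. The mesh shape `(n, 1)` and the axis names `"data"`/`"model"` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()

# Arrange all devices into an (n, 1) grid; on TPU the ordering is chosen to
# match the physical topology, elsewhere a simpler ordering is used.
devices = mesh_utils.create_device_mesh((n, 1))
mesh = Mesh(devices, axis_names=("data", "model"))  # axis names are arbitrary

# Split rows across the "data" axis and replicate along "model".
sharding = NamedSharding(mesh, P("data", None))
x = jax.device_put(
    jnp.arange(n * 4 * 4, dtype=jnp.float32).reshape(n * 4, 4), sharding
)

# Under jit, XLA propagates the input sharding through the computation.
y = jax.jit(lambda a: jnp.tanh(a) * 2.0)(x)

print(x.sharding)
print(len(y.addressable_shards), "shard(s) on", n, "device(s)")
```

On a single CPU this runs with one shard. On a TPU slice, each device holds a block of rows.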

You can inspect the topology you are working with by printing the coordinates of each TPU device:

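```python
import jax

# Each TPU device reports its position in the physical interconnect via `.coords`.
# `.coords` is TPU-specific; CPU devices do not have this attribute.
print([d.coords for d in jax.devices()])
```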
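If you want a compact summary rather than the raw list, a small helper can turn the coordinates into the extent of the device grid along each axis. `topology_extent` is a hypothetical helper written for this answer, not a JAX API. It assumes every device exposes `.coords` and returns `None` otherwise (for example on CPU). It only measures the bounding box of the coordinates, so it cannot tell a mesh from a torus: wraparound links are not visible in the coordinates themselves.

```python
import jax


def topology_extent(devices):
    """Hypothetical helper: size of the device grid along each coordinate axis.

    Returns None if the list is empty or any device lacks `.coords`
    (e.g. CPU devices).
    """
    coords = [getattr(d, "coords", None) for d in devices]
    if not coords or any(c is None for c in coords):
        return None
    # Transpose [(x, y, z), ...] into per-axis tuples and measure each span.
    per_axis = list(zip(*coords))
    return tuple(max(axis) - min(axis) + 1 for axis in per_axis)


print(topology_extent(jax.devices()))
```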

Does that answer your question?

Answer selected by yixiaoer
Category: Q&A