Have you seen the `jax.tools.colab_tpu.setup_tpu()` utility? It handles most of this setup logic for you. Here is an example: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#colab-tpu-setup
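
A minimal sketch of that setup, assuming a Colab TPU runtime that exposes the `COLAB_TPU_ADDR` environment variable and a JAX version that still ships `jax.tools.colab_tpu`. The `init_colab_tpu` wrapper is a hypothetical helper for illustration, not part of JAX:

```python
import os


def init_colab_tpu() -> bool:
    """Call setup_tpu() only when a Colab TPU runtime appears to be attached.

    Assumption: Colab TPU runtimes expose COLAB_TPU_ADDR. This wrapper is
    illustrative and not part of JAX.
    """
    if "COLAB_TPU_ADDR" not in os.environ:
        return False  # no TPU runtime detected; leave the default backend alone

    # Imported lazily so the sketch still runs where JAX is not installed.
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    return True


if init_colab_tpu():
    import jax

    print(jax.devices())  # on a TPU runtime this should list the TPU cores
else:
    print("No Colab TPU runtime detected; skipping TPU setup.")
```

The setup call should run before anything else touches a JAX backend, so it is best placed in the first cell of the notebook.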

I suspect the problem with your code comes from using the nightly driver, which sometimes breaks. That is why we generally pin the driver to a specific release that is known to work: https://github.com/google/jax/blob/f004bcb7b8763451b437202549faaea89f771f41/jax/tools/colab_tpu.py#L34-L37
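
To illustrate the difference, here is a rough sketch of how a driver version gets requested. It assumes the Colab TPU host accepts version requests on port 8475 under `/requestversion/`, which is my understanding of the linked lines. `driver_request_url`, the fallback address, and the pinned version string are all placeholders, so take the real version from the linked file:

```python
import os


def driver_request_url(tpu_addr: str, driver_version: str) -> str:
    """Build the URL used to ask the TPU host for a given driver version.

    Assumptions: tpu_addr has the form "host:port" (as in COLAB_TPU_ADDR), and
    the host serves version requests on port 8475 under /requestversion/.
    """
    host = tpu_addr.split(":")[0]
    return f"http://{host}:8475/requestversion/{driver_version}"


# Placeholder address so the sketch runs outside Colab.
addr = os.environ.get("COLAB_TPU_ADDR", "10.0.0.2:8470")

# Nightly build: tracks whatever was built most recently, so it can break.
print(driver_request_url(addr, "tpu_driver_nightly"))

# Pinned build: placeholder name; use the version from the linked colab_tpu.py.
print(driver_request_url(addr, "<pinned-driver-version>"))
```

In Colab the resulting URL would then be sent as an HTTP POST before JAX initializes its backend. The sketch only builds the URL, so it makes no network calls.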

Answer selected by jakevdp