Answered by jakevdp, Sep 2, 2021
Answer selected by jakevdp
Have you seen the `colab_tpu.setup_tpu()` utility? It takes care of most of this logic for you. Here is an example: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#colab-tpu-setup (or open directly in Colab).

I suspect the issue with your code has to do with using the `nightly` driver. Sometimes that causes problems, which is why we generally pin the driver to a specific release that is known to work: https://github.com/google/jax/blob/f004bcb7b8763451b437202549faaea89f771f41/jax/tools/colab_tpu.py#L34-L37
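To make the nightly-vs-pinned point concrete, here is a rough sketch, under stated assumptions, of how the manual Colab TPU setup that was widely shared at the time selected a driver version: it sent a request to the TPU host with the version name in the URL, then pointed JAX at the TPU over gRPC. The helper names `driver_request_url` and `backend_target`, the example address, and the `tpu_driver-PINNED` version string are hypothetical. Port 8475 and the `/requestversion/` path follow that community snippet, not a documented JAX API. The actual pinned version is the one in the linked `colab_tpu.py`.

```python
# Sketch (assumption): how the manual Colab TPU setup chose a driver version.
# Port 8475 and the /requestversion/ path come from the community snippet,
# not from a documented JAX API. Helper names are hypothetical.

DRIVER_REQUEST_PORT = 8475
NIGHTLY = "tpu_driver_nightly"


def driver_request_url(tpu_addr: str, driver_version: str) -> str:
    """URL that asks the Colab TPU host to load `driver_version`.

    `tpu_addr` is the value of the COLAB_TPU_ADDR environment variable,
    e.g. "10.0.0.2:8470"; only the host part is used here.
    """
    host = tpu_addr.split(":")[0]
    return f"http://{host}:{DRIVER_REQUEST_PORT}/requestversion/{driver_version}"


def backend_target(tpu_addr: str) -> str:
    """gRPC target for JAX's tpu_driver backend (full host:port is kept)."""
    return "grpc://" + tpu_addr


# Example with a made-up address. On Colab the address would come from
# os.environ["COLAB_TPU_ADDR"], and the URL would be sent with requests.post(url).
example_addr = "10.0.0.2:8470"
pinned_version = "tpu_driver-PINNED"  # hypothetical placeholder, not a real release name
print(driver_request_url(example_addr, NIGHTLY))
print(driver_request_url(example_addr, pinned_version))
print(backend_target(example_addr))
```

In this sketch, switching between `tpu_driver_nightly` and a pinned release changes only the last path segment of the request. That is why pinning the version, as `setup_tpu()` does, is a cheap way to sidestep a broken nightly build.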