JAX on TPU V4 #10357
agemagician asked this question in Q&A (unanswered)
Hi all,
We are currently testing JAX on TPU v4, and we have a question about the number of cores JAX should detect.
According to the TPU v4 early-access documentation, we should install libtpu from https://storage.googleapis.com/jax-releases/libtpu_releases.html. With that build, JAX sees only 4 cores.
However, I came across this README from the BigScience workshop:
https://github.com/bigscience-workshop/architecture-objective/blob/b5bfb9fb49fc86b2352c0b34e8a9dacf3ea052b8/bigscience/docs/tpu.md
It uses a different, custom libtpu build for TPU v4, with which JAX sees 8 cores.
Could you please confirm which libtpu version we should use with JAX for better performance, and explain the difference between JAX detecting 4 cores and 8 cores on TPU v4?
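One way to compare the two libtpu builds is to record what JAX reports after installing each one. The sketch below is an illustration that does not come from the original post. `describe_devices` and `double_per_device` are hypothetical helper names, and it assumes a single-host setup. It uses only standard JAX calls (`jax.devices`, `jax.device_count`, `jax.local_device_count`, `jax.default_backend`, `jax.pmap`). The second helper shows one practical consequence of the device count. A `pmap`-based program splits its batch over however many devices JAX sees, so the per-device shard size doubles or halves depending on whether 4 or 8 devices are visible.

```python
import collections

import jax
import jax.numpy as jnp


def describe_devices():
    """Summarize the devices JAX sees in this process (hypothetical helper)."""
    devices = jax.devices()
    return {
        "backend": jax.default_backend(),              # e.g. "tpu" or "cpu"
        "global_device_count": jax.device_count(),     # devices across all hosts
        "local_device_count": jax.local_device_count(),  # devices on this host
        # Count devices by their reported kind string.
        "device_kinds": dict(collections.Counter(d.device_kind for d in devices)),
    }


def double_per_device(global_batch):
    """Shard `global_batch` over local devices with pmap and double each shard.

    Hypothetical helper: pmap requires the leading axis to equal the local
    device count, so the shard size depends on how many devices JAX sees.
    """
    n = jax.local_device_count()
    if global_batch.shape[0] % n != 0:
        raise ValueError(
            f"batch size {global_batch.shape[0]} is not divisible by {n} devices"
        )
    # Reshape [batch, ...] -> [n_devices, batch // n_devices, ...].
    sharded = global_batch.reshape((n, -1) + global_batch.shape[1:])
    return jax.pmap(lambda x: x * 2)(sharded)
```

Running `print(describe_devices())` once per libtpu build gives a direct record of how many devices each build exposes. A global batch whose size is divisible by 8 works under either device count.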