It might be that your system has the wrong version of libtpu. Try installing the version pinned here: https://github.com/google/jax/blob/38884b02f72950a8d187f81420a270e46afe8889/setup.py#L27

I believe the easiest way to install the correct version is to run:

$ pip install -U jax
$ pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
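
To see which versions are actually installed before and after running the commands above, something like the following may help. It is only a minimal sketch using the standard library's `importlib.metadata`. The helper name `installed_versions` is made up for illustration, and the distribution names `libtpu` and `libtpu-nightly` are assumptions about how the TPU runtime is packaged, so adjust them to whatever the pinned `setup.py` lists.

```python
from importlib import metadata


def installed_versions(names=("jax", "jaxlib", "libtpu", "libtpu-nightly")):
    """Map each distribution name to its installed version, or None if absent.

    The default names are assumptions; check setup.py for the real ones.
    """
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Compare the printed libtpu version with the one pinned in the linked `setup.py`. If they differ, that would be consistent with the version mismatch suspected above.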

Answer selected by paramjeet2021