
Thanks for the question!

I think the current (undocumented) answer is that you can check the value of jax.config.jax_default_device, which is an Optional[Device], and, if it's None, fall back to jax.local_devices()[0]. That is:

import jax

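def get_default_device():
  # jax_default_device is None unless a default device has been configured;
  # in that case, fall back to the first local device.
  return jax.config.jax_default_device or jax.local_devices()[0]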
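A slightly more explicit variant of the same helper, with a short usage check, might look like the sketch below. It assumes, as the answer states, that the jax_default_device option holds either None or a Device. It uses `is not None` instead of `or`; both behave the same here because Device objects are always truthy. The pinned device in the example (the first CPU device) is just an illustrative choice.

```python
import jax


def get_default_device():
    """Return the device JAX uses by default for new arrays.

    Sketch based on the answer above: an explicitly configured
    jax_default_device wins; otherwise fall back to the first local device.
    """
    configured = jax.config.jax_default_device
    if configured is not None:
        return configured
    return jax.local_devices()[0]


# With no configuration, the helper falls back to the first local device.
print(get_default_device())

# Pin the default explicitly (illustrative choice: the first CPU device,
# which every JAX installation provides).
cpu = jax.devices("cpu")[0]
jax.config.update("jax_default_device", cpu)
print(get_default_device())  # now reports the CPU device
```

Resetting the option with jax.config.update("jax_default_device", None) should restore the fallback behaviour.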

What do you think?

Answer selected by muupan
Category
Q&A
2 participants