Is there a straightforward way to get JAX's current default device? I could not find such a function in the API documentation. I tried:

```python
import jax

print(jax.devices()[0])
with jax.default_device(jax.devices("cpu")[0]):
    print(jax.devices()[0])
```

but `jax.devices()[0]` does not reflect the device set by `jax.default_device`.
As a workaround, I found that creating a new array and checking its device works. But this is not straightforward, and I feel there should be a better way:

```python
import jax
import jax.numpy as jnp

print(jnp.empty(0).device())
with jax.default_device(jax.devices("cpu")[0]):
    print(jnp.empty(0).device())
```
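For reference, the array-based workaround can be wrapped in a small helper. This is only a sketch: `default_device_via_array` is a hypothetical name, and it uses `Array.devices()`, which returns the set of devices an array is placed on, instead of the `.device()` call above.

```python
import jax
import jax.numpy as jnp


def default_device_via_array():
    """Hypothetical helper: infer the default device from where a fresh array lands."""
    # A newly created array without an explicit device goes to the current default device.
    devices = jnp.empty(0).devices()  # set of devices holding the array
    # An unsharded array lives on exactly one device, so unpack the single element.
    (device,) = devices
    return device


print(default_device_via_array())
with jax.default_device(jax.devices("cpu")[0]):
    print(default_device_via_array())
```

The drawback the question points out remains: every call allocates a throwaway array just to read its placement.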
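Thanks for the question!

I think the current (undocumented) answer is that you can check the value of `jax.config.jax_default_device`, which is an `Optional[Device]`, and if it's `None` then use `jax.local_devices()[0]`. That is:

```python
import jax

def get_default_device():
    return jax.config.jax_default_device or jax.local_devices()[0]
```

What do you think?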
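One way to sanity-check this suggestion is to compare `get_default_device()` with the device a freshly created array lands on, both outside and inside a `jax.default_device` context. This is a sketch that assumes the default is set to a `Device` object, as in the question; `placement_of_new_array` is a hypothetical helper, and `get_default_device` is repeated so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp


def get_default_device():
    # jax_default_device is None unless a default has been set,
    # e.g. via the jax.default_device context manager.
    return jax.config.jax_default_device or jax.local_devices()[0]


def placement_of_new_array():
    # Hypothetical helper: the single device a freshly created array is placed on.
    (device,) = jnp.empty(0).devices()
    return device


cpu = jax.devices("cpu")[0]

print(get_default_device(), placement_of_new_array())
with jax.default_device(cpu):
    # Inside the context, both approaches should report the CPU device.
    print(get_default_device(), placement_of_new_array())
```

Unlike the array-based check, `get_default_device()` only reads configuration and does not allocate anything.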