Dear all, I need to get the ID of the device on which my computation takes place. For example, say I have 8 GPUs on my host and I need to multiply a tensor by a specific number on each GPU. There is a …
More precisely, I'm looking for the JAX equivalent of the following TensorFlow function: …
I assume you're using `pmap`? One option is to pass `jnp.arange(jax.device_count())` into the `pmap`, so every GPU receives a different value. Alternatively, you can use `jax.lax.axis_index`.
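A minimal sketch of both suggestions, assuming a single host where `pmap` runs over all local GPUs. It uses `jax.local_device_count()`, which equals `jax.device_count()` on a single machine like the 8-GPU host in the question. The per-device multipliers in `factors` and the axis name `"devices"` are hypothetical choices for illustration:

```python
import jax
import jax.numpy as jnp

# pmap maps over the devices attached to this host (e.g. 8 GPUs).
n = jax.local_device_count()

# Hypothetical per-device multipliers: device i scales its shard by factors[i].
factors = jnp.linspace(1.0, 2.0, n)

# One row of data per device (the leading axis is the mapped axis).
xs = jnp.ones((n, 3))

# Option 1: pass the index in as an extra mapped argument.
# Each device receives one scalar from jnp.arange(n): 0, 1, ..., n - 1.
def scale_with_arg(x, idx):
    return x * factors[idx]

out1 = jax.pmap(scale_with_arg)(xs, jnp.arange(n))

# Option 2: ask for this device's position along the named pmap axis.
def scale_with_axis_index(x):
    idx = jax.lax.axis_index("devices")  # 0 .. n - 1, differs per device
    return x * factors[idx]

out2 = jax.pmap(scale_with_axis_index, axis_name="devices")(xs)
```

Both versions should produce the same result. `axis_index` avoids threading an extra argument through the mapped function, which can be convenient when the index is needed deep inside a larger computation.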