How to use JAX pmap with CPU cores #19541
Replies: 2 comments 2 replies
-
Thanks for the question. The code you pasted is correct, and it works when I run it in a clean environment (i.e. one where no JAX operations have previously been executed). Can you make sure you're running this in a clean environment? If it still doesn't work, can you give us more information about your system? What is the output of
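To show what "clean environment" means in practice, here is a sketch that runs two small scripts in separate Python processes. The first sets `XLA_FLAGS` before JAX is imported. The second sets it only after a JAX operation has already initialized the backend. The subprocess harness, the hypothetical `cpu_device_count` helper, and forcing `JAX_PLATFORMS=cpu` are illustrative assumptions for a reproducible CPU-only comparison. They are not part of the original code.

```python
import os
import subprocess
import sys
import textwrap

FLAG = "--xla_force_host_platform_device_count=8"

# Case 1: the flag is set before JAX is imported, so XLA sees it
# when the CPU backend is first created.
CLEAN = textwrap.dedent(f"""
    import os
    os.environ["XLA_FLAGS"] = "{FLAG}"
    import jax
    print(len(jax.devices("cpu")))
""")

# Case 2: a JAX operation runs first and initializes the backend;
# the flag is set only afterwards, so this process never picks it up.
LATE = textwrap.dedent(f"""
    import os
    import jax.numpy as jnp
    jnp.ones(3).block_until_ready()
    os.environ["XLA_FLAGS"] = "{FLAG}"
    import jax
    print(len(jax.devices("cpu")))
""")


def cpu_device_count(script: str) -> int:
    """Run `script` in a fresh interpreter and return the device count it prints."""
    env = dict(os.environ)
    env.pop("XLA_FLAGS", None)    # don't inherit XLA flags from this process
    env["JAX_PLATFORMS"] = "cpu"  # assumption: restrict JAX to the CPU backend
    out = subprocess.run(
        [sys.executable, "-c", script],
        env=env, capture_output=True, text=True, check=True,
    )
    return int(out.stdout.strip().splitlines()[-1])


clean_count = cpu_device_count(CLEAN)
late_count = cpu_device_count(LATE)
print("flag set before JAX init:", clean_count)
print("flag set after JAX init: ", late_count)
```

On a CPU-only setup, this would typically report 8 devices for the first script and 1 for the second. The flag only has an effect if it is in place when the backend is created.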
-
Thanks for the reply. Yes, it is a clean environment: I am running this code as a standalone program in PyCharm, not in a Jupyter Notebook. When I do - Please let me know if you have any questions. I added a detailed description of the problem here.
-
I am trying to use JAX `pmap` with CPU cores, but I get an error saying that the XLA devices aren't visible.
Here's my code:
Here's the error:
Based on this discussion, I already set `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'`, so I am not sure what to do next.
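Here is a minimal sketch of the setup described in the question. It assumes a fresh Python process and a JAX version in which `jax.pmap` accepts a `devices` argument. It sets the flag first, then maps a function over the virtual CPU devices. The function (a per-row sum of squares), the array shape, and the name `row_sq_sum` are illustrative choices, not the poster's original code.

```python
import os

# Must run before JAX initializes any backend (safest: before `import jax`).
# Asks XLA to expose the host CPU as 8 virtual devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
n = len(cpu_devices)
print("CPU devices visible to JAX:", n)

# pmap maps over the leading axis; when `devices` is given, that axis
# must have exactly as many entries as there are devices in the list.
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

# Pin the mapped computation to the CPU devices explicitly, in case
# another backend (e.g. a GPU) would otherwise be the default.
row_sq_sum = jax.pmap(lambda row: jnp.sum(row ** 2), devices=cpu_devices)

result = row_sq_sum(x)  # one sum of squares per row, one row per device
print(result)
```

If `n` prints as 1 here, the flag was most likely applied after JAX had already initialized its backend in that process.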