pmapping across CPU cores #13220
I am using a TPU v3-8 VM from the TRC program, which is equipped with 8 TPU cores and 96 CPU cores. One experiment I would like to run is to compare the performance of the 8 TPU cores against the 96 CPU cores. However, it seems that with … Also, just to check: does JAX only use the TPU and ignore the CPU by default? Finally, how can I use both the CPU and the TPU? The use case is running logistic regression on a large number of samples (~200,000), and my code looks like this example.
Replies: 1 comment
There is an XLA flag, `xla_force_host_platform_device_count` (passed through the `XLA_FLAGS` environment variable), that you can set to force XLA to expose 96 CPU devices. Not sure how performant it would be, though. From what I remember it should be implemented in JAX; refer to this issue:
#8345
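A minimal sketch of the flag in use, assuming a JAX version where `jax.pmap` still accepts a `devices` argument. The flag has to be in `XLA_FLAGS` before JAX initializes its backends, so it is set before `import jax`. On a TPU VM the default backend is the TPU, but the CPU devices are still reachable through `jax.devices("cpu")`, so the sketch pins the `pmap` to them explicitly. Everything else is an assumption for illustration: the data is synthetic (the original code example isn't available), `NUM_CPU_DEVICES = 8` stands in for the 96 you would try on the v3-8 host, and the learning rate, step count, and feature count are arbitrary.

```python
import functools
import os

# The flag must be in XLA_FLAGS before JAX initializes its backends,
# so set it before `import jax`. 8 is illustrative; try 96 on the v3-8 host.
NUM_CPU_DEVICES = 8
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + f" --xla_force_host_platform_device_count={NUM_CPU_DEVICES}"
)

import jax
import jax.numpy as jnp
import numpy as np

cpu_devices = jax.devices("cpu")  # CPU devices exist even when a TPU is the default
n_dev = len(cpu_devices)  # equals NUM_CPU_DEVICES if the flag took effect
print("default backend:", jax.default_backend(), "| CPU devices:", n_dev)


def loss(w, b, x, y):
    # Mean binary cross-entropy of logistic regression, computed from logits.
    logits = x @ w + b
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)


@functools.partial(jax.pmap, axis_name="cpus", devices=cpu_devices)
def train_step(w, b, x, y):
    # Each CPU device computes gradients on its own shard of the data...
    grads = jax.grad(loss, argnums=(0, 1))(w, b, x, y)
    # ...then the gradients are averaged across all devices.
    gw, gb = jax.lax.pmean(grads, axis_name="cpus")
    lr = 0.1  # arbitrary learning rate for this sketch
    return w - lr * gw, b - lr * gb


# Synthetic stand-in for the real data: ~200,000 samples, trimmed so that
# they split evenly across the devices (200,000 is not divisible by 96).
n_features = 10
n_samples = 200_000 - 200_000 % n_dev
rng = np.random.default_rng(0)
x = rng.normal(size=(n_samples, n_features)).astype(np.float32)
true_w = rng.normal(size=n_features).astype(np.float32)
y = (x @ true_w > 0).astype(np.float32)

# Leading axis is the device axis: (n_dev, samples_per_device, ...).
x_sh = x.reshape(n_dev, -1, n_features)
y_sh = y.reshape(n_dev, -1)

# Parameters are replicated: one identical copy per device.
w = np.zeros((n_dev, n_features), dtype=np.float32)
b = np.zeros((n_dev,), dtype=np.float32)

for _ in range(50):
    w, b = train_step(w, b, x_sh, y_sh)

# All replicas hold the same parameters, so read them from device 0.
w_final, b_final = np.asarray(w[0]), float(b[0])
acc = np.mean(((x @ w_final + b_final) > 0) == (y > 0.5))
print(f"training accuracy: {acc:.3f}")
```

Because the shards are equal in size, averaging the per-shard mean gradients with `pmean` gives the same gradient as the full-batch mean. Each call re-sends the host NumPy shards to the CPU devices, so for an actual TPU-vs-CPU benchmark you would move the data once and time only the steps. Newer JAX releases also steer new code toward `jax.jit` with shardings (or `shard_map`) rather than `pmap`.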