Automatic parallel matmul? #8538
-
Hi. Based on the XLA Known Issues, an XLA program runs on exactly one device; on TPU, this means a jitted program runs on a single core.

A reinforcement learning algorithm consists of asking an agent for an action and updating the environment based on the agent's response. The training dataset is created on the fly, so there is no single training dataset that can be scattered to all cores and trained on in parallel. For performance, it is possible to jit the whole training function. The problem is: since the whole training function is jitted, it will only run on a single core. If the neural network is big, we want to utilize all the TPU cores in the part of the algorithm where the agent is asked (inference). This problem could be solved if the matmuls inside the jitted function were automatically parallelized across the TPU cores.

Can I kindly ask you if this is possible? Or how would you solve this problem? Thank you.
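To make the setup concrete, here is a minimal sketch of the pattern being described (the names `policy_apply`, `env_step`, and `train_step` are hypothetical placeholders, not code from this thread):

```python
import jax
import jax.numpy as jnp

def policy_apply(params, obs):
    # Hypothetical "big" network: one large matmul stands in for inference.
    return jnp.tanh(obs @ params)

def env_step(env_state, action):
    # Hypothetical environment update: the training data is produced on the fly.
    return env_state + action

@jax.jit
def train_step(params, env_state):
    # The whole step (inference + environment update + parameter update) is one
    # jitted XLA program, so it runs on exactly one device; the matmul inside
    # policy_apply is not spread across TPU cores automatically.
    obs = env_state
    action = policy_apply(params, obs)
    next_state = env_step(env_state, action)
    loss_fn = lambda p: jnp.mean((policy_apply(p, next_state) - action) ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - 1e-3 * grads, next_state
```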
-
For running multiple NNs in parallel, you could use `pmap` as you noted. I believe it will `jit` under the hood automatically (see #5681 (comment)). More info on how to use `pmap`: https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb
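A minimal sketch of the `pmap` route, assuming a data-parallel setup where the parameters are replicated and the inference batch is split across the local devices (the names and shapes below are made up for illustration):

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Hypothetical parameters and batch: one network applied to a batch whose
# leading dimension is split evenly across devices.
params = jnp.ones((128, 128))
batch = jnp.ones((n_dev * 8, 128))

# Replicate the parameters on every device and shard the batch along axis 0.
replicated_params = jax.device_put_replicated(params, jax.local_devices())
sharded_batch = batch.reshape(n_dev, -1, 128)

@jax.pmap
def apply_net(p, x):
    # Each device runs the same compiled computation on its own shard.
    return jnp.tanh(x @ p)

out = apply_net(replicated_params, sharded_batch)  # shape: (n_dev, 8, 128)
```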
If you want to parallelize a single matmul over multiple cores, then use `xmap` or `pjit`: https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html, https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html