Asynchronous dispatch in MPMDs? #6678
Unanswered
epignatelli asked this question in Q&A
Replies: 1 comment
This does not answer the question directly, but it may add more information on the original issue of implementing actor-learner architectures in RL: https://jax.readthedocs.io/en/latest/multi_process.html
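For reference, a minimal sketch of the multi-process setup those docs describe, assuming two processes; the coordinator address and process id below are placeholders, and each process would pass its own id:

```python
import jax

# Run one copy of this script per process/host. Placeholder values:
# process 0 acts as the coordinator, and each process must pass its
# own process_id (e.g. read from an environment variable).
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",
    num_processes=2,
    process_id=0,
)

print(jax.devices())        # all devices, across every process
print(jax.local_devices())  # only the devices attached to this process
```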
A common application of asynchronous computation in reinforcement learning is asynchronous actor-learner methods, e.g. A3C and IMPALA. Multiple actors collect experience in multiple environments concurrently on the GPU, and the data is often gathered either on the CPU or on a separate, master GPU.
I was trying to implement these using MPMD in JAX, but I am not sure that what I am writing works the way I thought. What rules does asynchronous dispatch in JAX follow for MPMD?
In particular, if I have two Python calls, each of which dispatches an XLA computation to a separate device, will they be executed asynchronously?
I ran some tests using the code below, and it seems like the second GPU waits for the first GPU.
Here's the Colab equivalent:
https://colab.research.google.com/drive/1TFPGndv6UaGsL_S0s7h1Bh5lfZYhwtB5?usp=sharing
Control
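(The notebook cells are not reproduced on this page; below is a minimal sketch of what the control might look like, with placeholder shapes and a `jnp.dot`-based workload standing in for the original computation.)

```python
import time

import jax
import jax.numpy as jnp

# A deliberately heavy computation, so execution time dominates dispatch overhead.
def a(x):
    return jnp.dot(x, x)

dev_0 = jax.devices()[0]
a_jit_dev_0 = jax.jit(a, device=dev_0)  # jit's device argument pins the program to GPU 0

x = jnp.ones((4096, 4096))
a_jit_dev_0(x).block_until_ready()  # warm-up call, so compilation is not timed

start = time.perf_counter()
y = a_jit_dev_0(x)     # returns immediately: dispatch is asynchronous
y.block_until_ready()  # block until the device has actually finished
print("control:", time.perf_counter() - start)
```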
Multiple programs, single GPU, twice, then reduce on separate GPU
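A sketch of this test, continuing from the control cell: both programs are enqueued on GPU 0 and the reduction runs on GPU 1. Computations enqueued on the same device execute in call order, so the second call cannot overlap the first.

```python
dev_1 = jax.devices()[1]
sum_jit_dev_1 = jax.jit(lambda u, v: u + v, device=dev_1)

start = time.perf_counter()
y0 = a_jit_dev_0(x)        # first program, enqueued on GPU 0
y1 = a_jit_dev_0(x)        # second program, queued behind y0 on GPU 0
z = sum_jit_dev_1(y0, y1)  # operands are transferred to GPU 1 and reduced there
z.block_until_ready()
print("single GPU, twice:", time.perf_counter() - start)
```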
Multiple programs, multiple GPUs, then reduce on a separate GPU
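And a sketch of the multi-GPU variant, which defines the `a_jit_dev_0` / `a_jit_dev_1` functions referred to below; if dispatch is asynchronous across devices, the two calls to `a` should overlap:

```python
a_jit_dev_1 = jax.jit(a, device=dev_1)

start = time.perf_counter()
y0 = a_jit_dev_0(x)        # dispatched to GPU 0
y1 = a_jit_dev_1(x)        # dispatched to GPU 1 without waiting for y0
z = sum_jit_dev_1(y0, y1)  # reduce on GPU 1 once both inputs are ready
z.block_until_ready()
print("multiple GPUs:", time.perf_counter() - start)
```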
Will the calls to `a_jit_dev_0` and `a_jit_dev_1` be executed asynchronously? From a quick test, it seems like the computation flickers between the two GPUs.