How to sequence parallel executions dynamically? #19610
benjaminvatterj asked this question in Q&A
Hi, JAX community!
I am trying to parallelize a function over batches of data, where each function call takes a variable amount of time. I have more batches than devices, so I loop over pmap calls. But because the run time varies, many devices often sit idle for a long time, waiting for the slowest batch in the current pmap call to finish.
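For reference, here is a minimal sketch of the looped-pmap setup described above. `work` is a hypothetical stand-in for the real per-batch function, and all batches are assumed to have the same shape so they can be stacked along a leading device axis.

```python
import jax
import jax.numpy as jnp
import numpy as np


def work(x):
    # Hypothetical stand-in for the real per-batch function.
    return jnp.sum(jnp.sin(x) ** 2)


p_work = jax.pmap(work)


def run_static(batches):
    """Process `batches` in fixed groups of one batch per device."""
    n_dev = jax.local_device_count()
    results = []
    for start in range(0, len(batches), n_dev):
        # Shape (k, ...) with k <= n_dev; the last group may be smaller.
        group = np.stack(batches[start:start + n_dev])
        # np.asarray waits for every device in the group, so the next group
        # is not dispatched until the slowest batch in this one has finished.
        results.extend(np.asarray(p_work(group)))
    return results
```

In this pattern, a device that finishes its batch early has nothing to do until the whole group completes, which is the idle time the question is about.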
Is there a way to assign tasks to devices dynamically, or some other way around this issue? For example, something like how multiprocessing.Pool hands batches to cores as they become available.
Any help would be greatly appreciated!
PS: I can't use vmap to get around the batches > devices problem because of memory constraints.
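One possible pattern (not from the original post, and untested on multi-GPU hardware) is to treat the devices like a multiprocessing.Pool: a host thread pool with one worker per device, where each task borrows whichever device is free, copies a single batch to it with `jax.device_put`, and runs a jitted function there. This assumes that dispatching JAX work from several Python threads behaves well for this workload; `work` and `run_dynamic` are hypothetical names. Because only one batch per device is on the device at a time, it should also stay within the memory limit that rules out vmap.

```python
import queue
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def work(x):
    # Hypothetical stand-in for the real per-batch function.
    return jnp.sum(jnp.sin(x) ** 2)


def run_dynamic(batches, devices=None):
    """Hand batches to devices as they become free, like a process pool."""
    devices = list(devices or jax.local_devices())
    free = queue.Queue()
    for d in devices:
        free.put(d)

    def run_one(batch):
        device = free.get()  # borrow an idle device
        try:
            # A committed input makes the jitted call execute on `device`.
            x = jax.device_put(batch, device)
            # np.asarray blocks until this batch is done on this device.
            return np.asarray(work(x))
        finally:
            free.put(device)  # return the device for the next batch

    # One worker thread per device; results come back in input order.
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        return list(pool.map(run_one, batches))
```

Unlike the looped pmap, a device that finishes early immediately picks up the next pending batch instead of waiting for the rest of a fixed group.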