I have a feed-forward neural network that is essentially a composition of N functions. I want to pipeline its training across multiple devices: execute some of these functions on one device, forward the result to a second device, execute some more functions there, and so on. So far, I think something like the following would work:
subfunctions = [...]  # a list of jit-ed functions, each executing one or more network layers
x = ...               # some provided input
for f in subfunctions:
    x = f(x)  # these get called asynchronously, right?
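As a point of reference for the loop above: JAX uses asynchronous dispatch, so each call to a jit-ed function returns an array handle before the device has finished computing it. A jit-ed function runs on the device holding its committed inputs, and `jax.device_put(x, device)` both commits an array to a device and performs the transfer. The sketch below makes the stage-to-device placement explicit. The names `stage` and `init_params`, the layer sizes, and the round-robin assignment of stages to devices are hypothetical illustrations, not part of the original question. Note that for a single batch, stage k+1 still has to wait for stage k's output, so the devices mostly take turns rather than run in parallel.

```python
import jax
import jax.numpy as jnp


def init_params(key, sizes):
    """Hypothetical dense-layer parameters: one (W, b) pair per pipeline stage."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


@jax.jit
def stage(p, x):
    """One hypothetical pipeline stage: a dense layer followed by tanh."""
    w, b = p
    return jnp.tanh(x @ w + b)


params = init_params(jax.random.PRNGKey(0), [8, 16, 16, 16, 4])
devices = jax.devices()
# Assumption: stage i lives on device i mod len(devices); with one device, all stages share it.
stage_devices = [devices[i % len(devices)] for i in range(len(params))]
params = [jax.device_put(p, d) for p, d in zip(params, stage_devices)]


def forward(params, x):
    for p, d in zip(params, stage_devices):
        x = jax.device_put(x, d)  # move the activation to this stage's device
        x = stage(p, x)           # asynchronous dispatch: returns before the result is ready
    return x


y = forward(params, jnp.ones((32, 8)))
y.block_until_ready()  # wait until every stage has actually finished
```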
In addition, the final device needs to send a "message" containing the backpropagated gradients back to the previous device, which applies the chain rule and in turn sends its own gradients further back.
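The backward "message passing" described above maps naturally onto `jax.vjp`. Calling it on one stage returns that stage's output together with a closure that turns an incoming cotangent (the gradient with respect to the stage's output) into the gradients for the stage's parameters and for its input. The input gradient is exactly what gets sent to the previous device. This is a minimal single-process sketch, assuming hypothetical `stage` and `loss_fn` functions and treating `jax.device_put` as the "send". The VJP closures keep each stage's residual activations alive until the backward pass consumes them.

```python
import jax
import jax.numpy as jnp


@jax.jit
def stage(p, x):
    """One hypothetical pipeline stage: dense layer + tanh."""
    w, b = p
    return jnp.tanh(x @ w + b)


def loss_fn(y, target):
    return jnp.mean((y - target) ** 2)


def pipeline_grads(params, stage_devices, x, target):
    # Forward: keep each stage's VJP closure; its residuals stay on that stage's device.
    vjp_fns = []
    for p, d in zip(params, stage_devices):
        x = jax.device_put(x, d)
        x, vjp_fn = jax.vjp(stage, p, x)
        vjp_fns.append(vjp_fn)
    target = jax.device_put(target, stage_devices[-1])
    loss, g = jax.value_and_grad(loss_fn)(x, target)
    # Backward: send the cotangent ("message") to each previous stage's device,
    # where the chain rule yields that stage's parameter grads and input cotangent.
    grads = [None] * len(params)
    for i in reversed(range(len(params))):
        g = jax.device_put(g, stage_devices[i])
        grads[i], g = vjp_fns[i](g)
    return loss, grads


# Usage with hypothetical sizes and a round-robin stage-to-device assignment.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
sizes = [8, 16, 16, 4]
params = [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
          for k, m, n in zip(keys, sizes[:-1], sizes[1:])]
stage_devices = [jax.devices()[i % jax.device_count()] for i in range(len(params))]
x, target = jnp.ones((32, 8)), jnp.zeros((32, 4))
loss, grads = pipeline_grads(params, stage_devices, x, target)


# Sanity check: per-stage gradients should match autodiff through the whole network.
def full_loss(params, x, target):
    for p in params:
        x = stage(p, x)
    return loss_fn(x, target)


ref_loss, ref_grads = jax.value_and_grad(full_loss)(params, x, target)
```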
I also need these steps to run concurrently, i.e., device 1 should already be calling its function on the next input while device 2 is just beginning to process the output it received from device 1.
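The overlap described above is usually obtained by splitting a batch into microbatches and issuing work in a "fill-drain" order (the GPipe-style schedule): at time step t, stage s works on microbatch t - s. The sketch below only relies on JAX's asynchronous dispatch to overlap stages, so whether devices actually run concurrently depends on Python dispatching faster than the devices compute. It is an illustration under that assumption, not a tuned pipeline. `stage`, `fill_drain_schedule`, the layer sizes, and the choice of 4 microbatches are hypothetical. The backward pass can be scheduled the same way in reverse stage order.

```python
import jax
import jax.numpy as jnp


@jax.jit
def stage(p, x):
    """One hypothetical pipeline stage: dense layer + tanh."""
    w, b = p
    return jnp.tanh(x @ w + b)


def fill_drain_schedule(num_stages, num_micro):
    """GPipe-style forward schedule: at step t, stage s works on microbatch t - s."""
    return [
        [(s, t - s) for s in range(num_stages) if 0 <= t - s < num_micro]
        for t in range(num_stages + num_micro - 1)
    ]


def pipelined_forward(params, stage_devices, microbatches):
    acts = list(microbatches)  # acts[m]: latest activation of microbatch m
    for step in fill_drain_schedule(len(params), len(microbatches)):
        for s, m in step:
            x = jax.device_put(acts[m], stage_devices[s])
            # Asynchronous dispatch lets stage s start on microbatch m while
            # stage s + 1 is still busy with microbatch m - 1 on its own device.
            acts[m] = stage(params[s], x)
    return acts


keys = jax.random.split(jax.random.PRNGKey(0), 3)
sizes = [8, 16, 16, 4]
params = [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
          for k, m, n in zip(keys, sizes[:-1], sizes[1:])]
devices = jax.devices()
stage_devices = [devices[i % len(devices)] for i in range(len(params))]
batch = jnp.ones((32, 8))
outs = pipelined_forward(params, stage_devices, jnp.split(batch, 4))
```

For example, `fill_drain_schedule(3, 2)` yields `[[(0, 0)], [(0, 1), (1, 0)], [(1, 1), (2, 0)], [(2, 1)]]`: stage 0 starts microbatch 1 in the same step that stage 1 starts microbatch 0.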
Is there native support in JAX for such operations, or should I be looking into something like mpi4jax?
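For devices visible to a single JAX program, JAX has built-in collectives that cover the "send to the next device" step. One is `jax.lax.ppermute`, which, inside a mapped function such as one wrapped by `jax.pmap`, shifts each device's value to another device according to a list of `(source, destination)` pairs. The sketch below writes the forward pass as an SPMD loop: every device applies its own stage, then passes its activation to the next device while device 0 reads the next microbatch. It assumes one stage per local device and that all stages have the same width `d`, so their weights can be stacked along a leading device axis. `spmd_pipeline`, `d`, and the microbatch shapes are hypothetical, and the same pattern can also be expressed with `shard_map`.

```python
from functools import partial

import jax
import jax.numpy as jnp

n = jax.local_device_count()                 # assumption: one pipeline stage per local device
perm = [(i, (i + 1) % n) for i in range(n)]  # device i sends its activation to device i + 1


@partial(jax.pmap, axis_name="stage", in_axes=(0, 0, None))
def spmd_pipeline(w, b, microbatches):
    """Each device holds one stage's (w, b); activations move forward via ppermute."""
    idx = jax.lax.axis_index("stage")
    num_micro = microbatches.shape[0]
    buf = jnp.zeros_like(microbatches[0])  # activation received from the previous stage
    outs = []
    for t in range(num_micro + n - 1):
        # Device 0 reads a fresh microbatch; every other device uses what it received.
        x = jnp.where(idx == 0, microbatches[min(t, num_micro - 1)], buf)
        y = jnp.tanh(x @ w + b)
        outs.append(y)
        buf = jax.lax.ppermute(y, axis_name="stage", perm=perm)
    # On the last device, step t holds the output for microbatch t - (n - 1).
    return jnp.stack(outs[n - 1:])


d = 16
keys = jax.random.split(jax.random.PRNGKey(0), n)
w = jnp.stack([jax.random.normal(k, (d, d)) / jnp.sqrt(d) for k in keys])
b = jnp.zeros((n, d))
microbatches = jnp.ones((4, 8, d))
out = spmd_pipeline(w, b, microbatches)[-1]  # outputs collected on the last stage
```

To try this on a CPU-only machine, setting `XLA_FLAGS=--xla_force_host_platform_device_count=4` in the environment before importing JAX exposes four host devices.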