Hi there!
I'm interested in implementing training strategies for neural networks similar to GPipe, i.e. strategies that place different parts of a network on different devices and then train them in a pipelined manner. For example, the second data batch could be fed into the first layer while the second layer is still processing the first batch, and errors could be backpropagated through the network while new batches are still moving forward.
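To make the forward schedule concrete, here is a minimal single-host sketch, assuming two hypothetical dense stages. `stage_fn` and `gpipe_forward` are illustrative names, not a JAX or GPipe API, and the Python loop only *simulates* the clock ticks; nothing actually runs concurrently here. At tick `t`, stage `s` works on microbatch `t - s`, which is exactly the "second batch enters layer 1 while layer 2 processes the first batch" pattern.

```python
import jax
import jax.numpy as jnp


def stage_fn(w, x):
    # One hypothetical pipeline stage: a dense layer followed by tanh.
    return jnp.tanh(x @ w)


def gpipe_forward(stage_params, microbatches):
    """Simulate a GPipe-style forward schedule on one host.

    At clock tick t, stage s processes microbatch t - s (if it exists), so
    stage 0 starts microbatch 1 while stage 1 works on microbatch 0.
    Returns per-microbatch outputs and a log of (tick, stage, microbatch).
    """
    num_stages, num_micro = len(stage_params), len(microbatches)
    inbox = [None] * num_stages  # activation waiting to enter each stage
    outputs = [None] * num_micro
    log = []
    for t in range(num_micro + num_stages - 1):
        # Visit stages back to front so stage s reads the activation that
        # stage s - 1 produced on the previous tick before it is overwritten.
        for s in reversed(range(num_stages)):
            m = t - s
            if not 0 <= m < num_micro:
                continue  # pipeline "bubble": nothing for this stage yet
            x = microbatches[m] if s == 0 else inbox[s]
            y = stage_fn(stage_params[s], x)
            log.append((t, s, m))
            if s + 1 < num_stages:
                inbox[s + 1] = y  # "send" the activation to the next stage
            else:
                outputs[m] = y
    return outputs, log


# Usage: two stages, a batch of 6 examples split into 3 microbatches.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = [jax.random.normal(k1, (4, 8)), jax.random.normal(k2, (8, 2))]
batch = jax.random.normal(k3, (6, 4))
micro = jnp.split(batch, 3)
outs, log = gpipe_forward(params, micro)
ref = stage_fn(params[1], stage_fn(params[0], batch))  # non-pipelined reference
print(log[:4])  # [(0, 0, 0), (1, 1, 0), (1, 0, 1), (2, 1, 1)]
```

Concatenating `outs` reproduces `ref`, since every stage here acts row by row; the pipeline only changes *when* each microbatch is processed, not the result.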
So far, such schemes look to me like message passing between devices, combined with some aggregation of results (e.g. computed parameter gradients) that are then processed centrally (e.g. by an optimizer). What is a good place to start if I want to implement something like this? I'm pretty new to JAX: I've read about its asynchronous execution model and about functions like `pmap` and `xmap`, and I've come across the mpi4jax library. However, I'm rather confused about the best way to implement this in JAX, and about how much "wheel reinvention" I could avoid. I appreciate any tips!
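The message-passing view maps fairly directly onto JAX collectives. Below is a minimal sketch, assuming three hypothetical stages with identical (4, 4) weights stacked along a leading axis. Each stage hands its activation to the next one with `jax.lax.ppermute`, and `jax.vmap(..., axis_name="stages")` emulates the devices so the example runs on a single CPU. My assumption, not tested here, is that the same per-stage function could be passed to `jax.pmap` with the same `axis_name` on a host with three devices. `pipeline_tick`, `run_pipeline`, and `sequential` are illustrative names. For the aggregation part, `jax.grad` differentiates straight through the pipeline and returns one gradient per stage, already summed over microbatches, which a central optimizer step can consume. Note that this leaves the backward pass to JAX/XLA; it does not hand-schedule backward/forward overlap.

```python
import jax
import jax.numpy as jnp
from jax import lax

NUM_STAGES = 3  # hypothetical: one (emulated) device per stage


def stage_fn(w, x):
    # One hypothetical pipeline stage: a square dense layer followed by tanh.
    return jnp.tanh(x @ w)


def pipeline_tick(w, state, new_input, stage_id):
    # Runs once per "device": apply this device's stage to its current input.
    y = stage_fn(w, state)
    # Message passing: device i sends y to device i + 1. The permutation is
    # cyclic (complete); whatever device 0 receives is discarded below.
    perm = [(i, (i + 1) % NUM_STAGES) for i in range(NUM_STAGES)]
    received = lax.ppermute(y, axis_name="stages", perm=perm)
    # Device 0 takes the next microbatch instead of a message.
    next_state = jnp.where(stage_id == 0, new_input, received)
    return next_state, y


# vmap with axis_name emulates NUM_STAGES devices on a single host.
tick = jax.vmap(pipeline_tick, in_axes=(0, 0, None, 0), axis_name="stages")


def run_pipeline(weights, microbatches):
    num_micro, mb, d = microbatches.shape
    stage_ids = jnp.arange(NUM_STAGES)
    state = jnp.zeros((NUM_STAGES, mb, d)).at[0].set(microbatches[0])
    outputs = []
    for t in range(num_micro + NUM_STAGES - 1):
        nxt = microbatches[t + 1] if t + 1 < num_micro else jnp.zeros((mb, d))
        state, ys = tick(weights, state, nxt, stage_ids)
        if t >= NUM_STAGES - 1:
            # The last stage just finished microbatch t - (NUM_STAGES - 1).
            outputs.append(ys[-1])
    return jnp.stack(outputs)


def sequential(weights, x):
    # Reference: run the stages one after another on a single microbatch.
    for s in range(NUM_STAGES):
        x = stage_fn(weights[s], x)
    return x


kw, kx = jax.random.split(jax.random.PRNGKey(0))
weights = 0.5 * jax.random.normal(kw, (NUM_STAGES, 4, 4))  # stacked per-stage params
micro = jax.random.normal(kx, (5, 2, 4))  # 5 microbatches of 2 examples each
out = run_pipeline(weights, micro)
ref = jnp.stack([sequential(weights, m) for m in micro])

# Gradients flow back through ppermute; the result has one slice per stage,
# summed over all microbatches (GPipe-style gradient accumulation).
grads = jax.grad(lambda w: jnp.sum(run_pipeline(w, micro) ** 2))(weights)
ref_grads = jax.grad(
    lambda w: sum(jnp.sum(sequential(w, m) ** 2) for m in micro))(weights)
new_weights = weights - 0.1 * grads  # a plain SGD step applied centrally
```

Stacking the per-stage weights along a leading axis is what lets a single function be mapped over stages; it assumes every stage has the same parameter shapes, which real models may not satisfy without padding or per-stage dispatch.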
Thanks in advance!