Pipelining, xmap, and the global device array #9696
Replies: 4 comments 2 replies
-
TPUs don't really have a notion of "node" other than the literal physical host machine; the TPU chips are connected in a homogeneous ICI network that doesn't care about host boundaries. That's actually why we made GlobalDeviceArray: to provide an abstraction for inputs and outputs of global computations that span many hosts, without forcing users to special-case host boundaries.
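A minimal sketch of how a GDA gets built in user code (assuming the 2022-era experimental GDA API; the shapes, axis names, and the `local_shard` helper are illustrative) - each host only materializes the shards that live on its own devices:

```python
import numpy as np
import jax
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray

global_shape = (32, 128)                          # logical array spanning all hosts
devices = np.array(jax.devices()).reshape(32, 1)  # e.g. a 32-device slice
mesh = Mesh(devices, ('stage', 'batch'))

def local_shard(index):
    # `index` selects the slice of the global array that lives on one local
    # device; each host is only ever asked for its own shards.
    return np.ones(global_shape, dtype=np.float32)[index]

gda = GlobalDeviceArray.from_callback(
    global_shape, mesh, P('stage', 'batch'), local_shard)
```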
-
FYI, there's an implementation of pipelining and PipelinedTransformer that works on TPU at https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/pipeline.py. Also, see the GSPMD paper: https://arxiv.org/abs/2105.04663
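For reference, the buffer shift in that implementation style can be written as an ordinary pad-and-slice along the (sharded) stage dimension, which, as noted in the reply below, the compiler lowers to a collective-permute between stage shards. A tiny sketch of that pattern (not the lingvo code itself; the helper name and shapes are illustrative):

```python
import jax.numpy as jnp

def shift_stage_buffer(buf):
    # buf: [num_stages, ...] buffer, sharded along the leading stage dimension.
    # Prepend a zero slot and drop the last slot, moving every entry one stage
    # forward; on a sharded axis this pad/slice becomes a collective-permute.
    pad_width = [(1, 0)] + [(0, 0)] * (buf.ndim - 1)
    return jnp.pad(buf, pad_width)[:-1]
```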
-
Thanks for the quick response - it just took some time to read everything.

@zhangqiaorjc - thank you for that paper and code! The shifting buffer is exactly the concept I was getting at above, but it's much better expressed than in my bullets :) . Collective-permute is the operation I was looking for to shift the buffer. I couldn't find a use of lax.ppermute in the code, but the tensorflow gshard comments mention that code identical to this line induces a collective-permute - which makes sense.

@jekbradbury - I think this reveals a point where my understanding of the multi-host abstraction cracks a little, in particular in the transition from CPU to the mesh. I think I'd be able to figure this out if I was able to play around with a pod slice, but I don't currently have access to one. I was hoping to write code that translates correctly anyway, so I was wondering if I could check my understanding here.

The psum pod-slice example translates trivially to a data-parallel process: instead of loading an array of ones, each host could load a random batch from some dataset, sum grads over whichever computation was performed, etc.

Now for a contrived example - but one which (I think?) applies to the pipeline case (ignoring GPipe-style microbatching for the moment; let's just use an inefficient, simple example). What if you wanted to reshape the global device array into a [32, 1] mesh - a single long pipeline with dims [stage, batch]? Our program is simple: x *= y_i. By the end of the pipeline, a 1 which enters at the first device should equal 1 * y_0 * y_1 * ... * y_n.

One way of achieving this would be to use the shifting buffer described above (see the sketch at the end of this comment). Construct a [32, 1] tensor where the zeroth entry is 1 and all others are 0, and xmap the first dimension across the stage dim; the same goes for the params array y. This allows us to run the same program (y ~= N transformer blocks with their own parameters) with different params on an input value that is moving along the pipeline. At each iteration (using scan), you could ppermute the data along the input tensor so that the result of the previous timestep lands one index further along the 'stage' dim. This could also be achieved via the slice/pad from the code example provided.

This is my intuition of how it 'should' work. Where I'm uncertain is how to handle the initial construction of the input data tensor. E.g., should each host load exactly the same batch (the [1, 0, 0, 0, 0, 0, ...] tensor described), then simply apply xmap across stages in the global device array? It feels slightly strange for each host to load the same data, so I don't feel confident I'm correct.

PS: The jax doc search seems to only return results for whole words. E.g., if I search permute, lax's ppermute (collective) doesn't show up - I need to search ppermute (or collective). This could well be an intentional choice - just flagging it in case it isn't.

I also have a question about AllReduce time complexity w.r.t. the number of partitions (but please don't feel any pressure to answer if this isn't the right place). This may reveal my ignorance about the trade-offs between tensor and pipeline parallelism, but it was my understanding that tensor parallelism over a large number of devices incurs large communication overheads due to the all-reduces every layer. Fig. 9 in https://arxiv.org/pdf/2006.16668.pdf shows that it has essentially constant cost on TPUs even up to 256 shards - though that figure starts above 16 partitions (does the same hold below 8?).

I had assumed this was one of the reasons PP was useful - that its cost scales roughly constantly with the number of stages, with relatively minimal communication overhead (at the expense of bubbles). Do you know of any papers that discuss these trade-offs with hard numbers? It's obviously hardware-specific, but I'd like to build a stronger intuition for the magnitude of the time cost. Papers like GPipe motivate with this reasoning, but don't show graphs!
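To make the shifting-buffer loop above concrete, here is a rough, single-process sketch using xmap over a 'stage' axis with lax.ppermute inside lax.scan (assuming the 2022-era jax.experimental.maps API; N_STAGE, the mesh layout, and the toy x *= y stage computation are all illustrative, and the multi-host input construction is exactly the part left open above):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import Mesh, xmap

N_STAGE = jax.device_count()   # e.g. 8 locally, 32 on a pod slice

def pipeline(x, y):
    # x and y each carry a named 'stage' axis; inside this function they
    # behave like per-stage scalars.
    def step(carry, _):
        carry = carry * y                              # this stage's work
        # Shift the buffer one stage forward; lowers to a collective-permute
        # over the interconnect on TPU.
        carry = jax.lax.ppermute(
            carry, axis_name='stage',
            perm=[(i, (i + 1) % N_STAGE) for i in range(N_STAGE)])
        return carry, None
    out, _ = jax.lax.scan(step, x, None, length=N_STAGE)
    return out

run = xmap(pipeline,
           in_axes=(['stage', ...], ['stage', ...]),
           out_axes=['stage', ...],
           axis_resources={'stage': 'devices'})

with Mesh(np.array(jax.devices()), ('devices',)):
    x0 = jnp.zeros(N_STAGE).at[0].set(1.0)   # a single 1 enters at stage 0
    y = jnp.arange(2.0, N_STAGE + 2.0)       # per-stage "parameters"
    print(run(x0, y))  # stage 0 holds prod(y) after N_STAGE wrap-around steps
```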
-
So is it the case that, for pipeline parallelism specifically, mpi4jax is the correct tool? Or can that sort of thing be done in "pure" jax?
-
Hi team!
I'm exploring different scaling methods in jax, and am interested in understanding the best way to implement pipeline parallelism across multiple TPU nodes.
With torch/NVIDIA, I believe NCCL allows for multi-node inter-GPU communication. I was wondering whether there is a jax equivalent that leverages the ICI links between TPU nodes (when one has access to a v32 or greater) and allows send/recv-style operations in addition to global collectives like mean and sum.
Initially I thought that mpi4jax would suit - which it does when using cpu/gpu (it allows explicit send/recv between each node). However, it doesn't have TPU support.
Ray also works for coordinating TPU nodes (à la Ben Wang's mesh-transformer repo), but as I understand it, the CPU copies required to communicate between nodes lower the efficiency.
I thought that xmap / the global device array would allow for this extremely elegantly. E.g.
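Something along these lines, very roughly (purely an illustrative sketch - N_STAGE, the axis name, and the toy stage_fn are placeholders rather than working multi-host code):

```python
import jax
from jax.experimental.maps import xmap

N_STAGE = 8  # say, one pipeline stage per device within a node

def stage_fn(x, params):
    y = x * params  # stand-in for one transformer block
    # hand this stage's output to the next stage each step
    return jax.lax.ppermute(
        y, axis_name='stage',
        perm=[(i, (i + 1) % N_STAGE) for i in range(N_STAGE)])

# map the 'stage' dimension of the inputs over devices
# (plus axis_resources / a Mesh to actually place stages on hardware)
pipeline_step = xmap(stage_fn,
                     in_axes=(['stage', ...], ['stage', ...]),
                     out_axes=['stage', ...])
```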
This seems like it would work quite well within a node of 8 machines (where it makes sense to easily pass data outputs in a loop from one set of devices to another), and could improve efficiency somewhat over an exclusively tensor-parallel architecture (Eleuther's 20B model found that the optimal PP/MP split within a node was 4,2).
However, my read of the current docs doesn't indicate an easy way of sending data between nodes, so I wasn't sure how to extend pipeline parallelism to an efficient inter-node setup (assuming one wanted to be faster than Ray) that leverages the ICI network the way collective operations like mean or sum do.