|
2D Parallelism (Tensor Parallelism + FSDP)
##########################################

2D Parallelism combines Tensor Parallelism (TP) and Fully Sharded Data Parallelism (FSDP) to leverage the memory efficiency of FSDP and the computational scalability of TP.
This hybrid approach balances the trade-offs of each method, optimizing memory usage and minimizing communication overhead, which makes it possible to train extremely large models on large GPU clusters.

The :doc:`Tensor Parallelism documentation <tp>` and a general understanding of `FSDP <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`_ are prerequisites for this tutorial.

.. note:: This is an experimental feature.


----


*********************
Enable 2D parallelism
*********************

We will start off with the same feed forward example model as in the :doc:`Tensor Parallelism tutorial <tp>`.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

Next, we define a function that applies the desired parallelism to our model.
The function must take the model as its first argument and a :class:`~torch.distributed.device_mesh.DeviceMesh` as its second argument.
More on how the device mesh works later.

.. code-block:: python

    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed._composable.fsdp.fully_shard import fully_shard

    def parallelize_feedforward(model, device_mesh):
        # Lightning will set up a device mesh for you
        # Here, it is 2-dimensional
        tp_mesh = device_mesh["tensor_parallel"]
        dp_mesh = device_mesh["data_parallel"]

        if tp_mesh.size() > 1:
            # Use PyTorch's distributed tensor APIs to parallelize the model
            plan = {
                "w1": ColwiseParallel(),
                "w2": RowwiseParallel(),
                "w3": ColwiseParallel(),
            }
            parallelize_module(model, tp_mesh, plan)

        if dp_mesh.size() > 1:
            # Use PyTorch's FSDP2 APIs to parallelize the model
            fully_shard(model.w1, mesh=dp_mesh)
            fully_shard(model.w2, mesh=dp_mesh)
            fully_shard(model.w3, mesh=dp_mesh)
            fully_shard(model, mesh=dp_mesh)

        return model

By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable.
In addition to the tensor-parallel code from the :doc:`Tensor Parallelism tutorial <tp>`, this function also shards the model's parameters using FSDP along the data-parallel dimension.

Finally, pass the parallelization function to the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` and configure the data-parallel and tensor-parallel sizes:

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import ModelParallelStrategy

    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize_feedforward,
        # Define the size of the 2D parallelism
        # Set these to "auto" (default) to apply TP intra-node and FSDP inter-node
        data_parallel_size=2,
        tensor_parallel_size=2,
    )

    fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

In this example with 4 GPUs, Fabric will create a device mesh that groups GPUs 0-1 and GPUs 2-3 (2 groups because ``data_parallel_size=2``, and 2 GPUs per group because ``tensor_parallel_size=2``).
Later, when ``fabric.setup(model)`` is called, each layer wrapped with FSDP (``fully_shard``) is split into two shards, one for the GPU 0-1 group and one for the GPU 2-3 group.
Finally, tensor parallelism is applied within each group, splitting the sharded tensors across the GPUs of that group.

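To build intuition for this grouping, the sketch below constructs a comparable 2D mesh by hand with PyTorch's device-mesh API, based on the layout described above.
This is purely illustrative: when using the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy`, Fabric creates the mesh and passes it to your ``parallelize_fn``, so you never need to call this yourself.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # Illustrative 2D mesh for 4 GPUs, assuming ranks are laid out as [[0, 1], [2, 3]]:
    # - "tensor_parallel" groups: {0, 1} and {2, 3} (TP splits tensors within each group)
    # - "data_parallel" groups: {0, 2} and {1, 3} (FSDP shards parameters across these)
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("data_parallel", "tensor_parallel"))

    # Named indexing returns the 1D sub-meshes used in `parallelize_feedforward`
    dp_mesh = mesh["data_parallel"]    # size 2
    tp_mesh = mesh["tensor_parallel"]  # size 2
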
.. collapse:: Full training example (requires at least 4 GPUs).

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
        from torch.distributed.tensor.parallel import parallelize_module
        from torch.distributed._composable.fsdp.fully_shard import fully_shard

        import lightning as L
        from lightning.pytorch.demos.boring_classes import RandomDataset
        from lightning.fabric.strategies import ModelParallelStrategy


        class FeedForward(nn.Module):
            def __init__(self, dim, hidden_dim):
                super().__init__()
                self.w1 = nn.Linear(dim, hidden_dim, bias=False)
                self.w2 = nn.Linear(hidden_dim, dim, bias=False)
                self.w3 = nn.Linear(dim, hidden_dim, bias=False)

            def forward(self, x):
                return self.w2(F.silu(self.w1(x)) * self.w3(x))


        def parallelize_feedforward(model, device_mesh):
            # Lightning will set up a device mesh for you
            # Here, it is 2-dimensional
            tp_mesh = device_mesh["tensor_parallel"]
            dp_mesh = device_mesh["data_parallel"]

            if tp_mesh.size() > 1:
                # Use PyTorch's distributed tensor APIs to parallelize the model
                plan = {
                    "w1": ColwiseParallel(),
                    "w2": RowwiseParallel(),
                    "w3": ColwiseParallel(),
                }
                parallelize_module(model, tp_mesh, plan)

            if dp_mesh.size() > 1:
                # Use PyTorch's FSDP2 APIs to parallelize the model
                fully_shard(model.w1, mesh=dp_mesh)
                fully_shard(model.w2, mesh=dp_mesh)
                fully_shard(model.w3, mesh=dp_mesh)
                fully_shard(model, mesh=dp_mesh)

            return model


        strategy = ModelParallelStrategy(
            parallelize_fn=parallelize_feedforward,
            data_parallel_size=2,
            tensor_parallel_size=2,
        )

        fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
        fabric.launch()

        # Initialize the model
        model = FeedForward(8192, 8192)
        model = fabric.setup(model)

        # Define the optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
        optimizer = fabric.setup_optimizers(optimizer)

        # Define dataset/dataloader
        dataset = RandomDataset(8192, 128)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        dataloader = fabric.setup_dataloaders(dataloader)

        # Simplified training loop
        for i, batch in enumerate(dataloader):
            output = model(batch)
            loss = output.sum()
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            fabric.print(f"Iteration {i} complete")

        fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

|

Beyond this toy example, we recommend you study our `LLM 2D Parallel Example (Llama 3) <https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel>`_.


----


*******************
Effective use cases
*******************

In the toy example above, the parallelization is configured to work within a single machine across multiple GPUs.
In practice, however, the main use case for 2D parallelism is multi-node training, where the two methods can be combined effectively to maximize throughput and model scale.
Since tensor parallelism requires blocking collective calls, fast GPU-to-GPU data transfers are essential to keep throughput high, and therefore TP is typically applied across the GPUs within a machine.
FSDP, on the other hand, has the advantage that it can by design overlap GPU transfers with computation (it can prefetch layers).
Hence, combining FSDP for inter-node parallelism with TP for intra-node parallelism is generally a good strategy to minimize both latency and network bandwidth usage, making it possible to scale to much larger models than with FSDP alone.


.. code-block:: python

    from lightning.fabric.strategies import ModelParallelStrategy

    strategy = ModelParallelStrategy(
        # Default is "auto"
        # Applies TP intra-node and FSDP inter-node
        data_parallel_size="auto",
        tensor_parallel_size="auto",
    )

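If you prefer to set the sizes explicitly, the product of ``data_parallel_size`` and ``tensor_parallel_size`` should match the total number of GPUs.
As a sketch, on a hypothetical cluster of 2 nodes with 8 GPUs each, the following keeps TP within each node and lets FSDP shard across the nodes:

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import ModelParallelStrategy

    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize_feedforward,
        data_parallel_size=2,    # FSDP shards across the 2 nodes
        tensor_parallel_size=8,  # TP splits tensors across the 8 GPUs within each node
    )

    # 2 nodes x 8 GPUs each = 16 processes in total
    fabric = L.Fabric(accelerator="cuda", devices=8, num_nodes=2, strategy=strategy)
    fabric.launch()
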
----


***************************
Data-loading considerations
***************************

In a tensor-parallelized model, it is important that the model receives an identical input on each GPU that participates in the same tensor-parallel group.
Across the data-parallel dimension, however, the inputs should be different.
In other words, if TP is applied within a node and FSDP across nodes, each node must receive a different batch, but every GPU within a node gets the same batch of data.

If you use a PyTorch data loader and set it up with :meth:`~lightning.fabric.fabric.Fabric.setup_dataloaders`, Fabric will automatically handle this for you by configuring the distributed sampler.
However, when you shuffle data in your dataset or data loader, or apply randomized transformations/augmentations to your data, you must still ensure that the seed is set appropriately.


.. code-block:: python

    import lightning as L

    fabric = L.Fabric(...)

    # Define dataset/dataloader
    # If there is randomness/augmentation in the dataset, fix the seed
    dataset = MyDataset(seed=42)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Fabric configures the sampler automatically for you such that
    # all batches in a tensor-parallel group are identical,
    # while still sharding the dataset across the data-parallel group
    dataloader = fabric.setup_dataloaders(dataloader)

    for i, batch in enumerate(dataloader):
        ...

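If your dataset does not expose a seed argument like the hypothetical ``MyDataset`` above, one simple option (a sketch, not the only approach) is to seed all random number generators globally in every process before creating the dataset and dataloader:

.. code-block:: python

    import lightning as L

    # Seeds the Python, NumPy, and PyTorch RNGs identically in every process, so that
    # shuffling and random augmentations stay in sync within a tensor-parallel group
    L.seed_everything(42)
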
----


**********
Next steps
**********

.. raw:: html

    <div class="display-card-container">
    <div class="row">

.. displayitem::
    :header: LLM 2D Parallel Example
    :description: Full example of how to combine TP + FSDP in a large language model (Llama 3)
    :col_css: col-md-4
    :button_link: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel
    :height: 160
    :tag: advanced

.. displayitem::
    :header: Pipeline Parallelism
    :description: Coming soon
    :col_css: col-md-4
    :height: 160
    :tag: advanced


.. raw:: html

    </div>
    </div>

|