
Commit 341474a

(8/n) Support 2D Parallelism - 2D Parallel Fabric Docs (#19887)
1 parent 8fc7b4a commit 341474a

4 files changed, +292 −8 lines

docs/source-fabric/advanced/model_init.rst

Lines changed: 6 additions & 6 deletions

@@ -61,15 +61,15 @@ When loading a model from a checkpoint, for example when fine-tuning, set ``empt
 ----


-********************************************
-Model-parallel training (FSDP and DeepSpeed)
-********************************************
+***************************************************
+Model-parallel training (FSDP, TP, DeepSpeed, etc.)
+***************************************************

-When training sharded models with :doc:`FSDP <model_parallel/fsdp>` or DeepSpeed, using :meth:`~lightning.fabric.fabric.Fabric.init_module` is necessary in most cases because otherwise model initialization gets very slow (minutes) or (and that's more likely) you run out of CPU memory due to the size of the model.
+When training distributed models with :doc:`FSDP/TP <model_parallel/index>` or DeepSpeed, using :meth:`~lightning.fabric.fabric.Fabric.init_module` is necessary in most cases because otherwise model initialization gets very slow (minutes) or (and that's more likely) you run out of CPU memory due to the size of the model.

 .. code-block:: python

-    # Recommended for FSDP and DeepSpeed
+    # Recommended for FSDP, TP and DeepSpeed
     with fabric.init_module(empty_init=True):
         model = GPT3()  # parameters are placed on the meta-device

@@ -81,4 +81,4 @@ When training sharded models with :doc:`FSDP <model_parallel/fsdp>` or DeepSpeed

 .. note::
     Empty-init is experimental and the behavior may change in the future.
-    For FSDP on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
+    For distributed models on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
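
For reference, the ``reset_parameters()`` requirement in that note looks roughly like this in a user-defined module. This is an illustrative sketch, not part of the commit; ``ScaledLinear`` is a made-up example, and the method is what re-initializes the weights once they have been materialized from the meta device.

.. code-block:: python

    import math

    import torch
    import torch.nn as nn


    class ScaledLinear(nn.Module):
        """Hypothetical user-defined module that manages its own parameters."""

        def __init__(self, dim):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(dim, dim))
            self.reset_parameters()

        def reset_parameters(self):
            # Called again after the meta-device parameters are materialized
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        def forward(self, x):
            return x @ self.weight.t()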

docs/source-fabric/advanced/model_parallel/index.rst

Lines changed: 0 additions & 1 deletion

@@ -82,7 +82,6 @@ Get started
 .. displayitem::
    :header: Pipeline Parallelism
    :description: Coming soon
-   :button_link:
    :col_css: col-md-4
    :height: 180
    :tag: advanced

docs/source-fabric/advanced/model_parallel/tp.rst

Lines changed: 3 additions & 0 deletions

@@ -110,6 +110,7 @@ In Fabric, define a function that applies the tensor parallelism to the model:
         parallelize_module(model, tp_mesh, plan)
         return model

+By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable.
 Next, configure the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` in Fabric:

 .. code-block:: python

@@ -222,6 +223,8 @@ Beyond this toy example, we recommend you study our `LLM Tensor Parallel Example
 ----


+.. _tp-data-loading:
+
 ***************************
 Data-loading considerations
 ***************************

docs/source-fabric/advanced/model_parallel/tp_fsdp.rst

Lines changed: 283 additions & 1 deletion

@@ -2,4 +2,286 @@
 2D Parallelism (Tensor Parallelism + FSDP)
 ##########################################

-Content will be available soon.

The remaining lines add the new page content:

2D Parallelism combines Tensor Parallelism (TP) and Fully Sharded Data Parallelism (FSDP) to leverage the memory efficiency of FSDP and the computational scalability of TP.
This hybrid approach balances the trade-offs of each method, optimizing memory usage and minimizing communication overhead, and enables training extremely large models on large GPU clusters.

The :doc:`Tensor Parallelism documentation <tp>` and a general understanding of `FSDP <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`_ are prerequisites for this tutorial.

.. note:: This is an experimental feature.


----


*********************
Enable 2D parallelism
*********************

We will start with the same feed-forward example model as in the :doc:`Tensor Parallelism tutorial <tp>`.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class FeedForward(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

Next, we define a function that applies the desired parallelism to our model.
The function must take the model as its first argument and a :class:`~torch.distributed.device_mesh.DeviceMesh` as its second argument.
More on how the device mesh works later.

.. code-block:: python

    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed._composable.fsdp.fully_shard import fully_shard


    def parallelize_feedforward(model, device_mesh):
        # Lightning will set up a device mesh for you
        # Here, it is 2-dimensional
        tp_mesh = device_mesh["tensor_parallel"]
        dp_mesh = device_mesh["data_parallel"]

        if tp_mesh.size() > 1:
            # Use PyTorch's distributed tensor APIs to parallelize the model
            plan = {
                "w1": ColwiseParallel(),
                "w2": RowwiseParallel(),
                "w3": ColwiseParallel(),
            }
            parallelize_module(model, tp_mesh, plan)

        if dp_mesh.size() > 1:
            # Use PyTorch's FSDP2 APIs to parallelize the model
            fully_shard(model.w1, mesh=dp_mesh)
            fully_shard(model.w2, mesh=dp_mesh)
            fully_shard(model.w3, mesh=dp_mesh)
            fully_shard(model, mesh=dp_mesh)

        return model

By writing the parallelization code in a separate function rather than hardcoding it into the model, we keep the original source code clean and maintainable.
In addition to the tensor-parallel code from the :doc:`Tensor Parallelism tutorial <tp>`, this function also shards the model's parameters with FSDP along the data-parallel dimension.

Finally, pass the parallelization function to the :class:`~lightning.fabric.strategies.model_parallel.ModelParallelStrategy` and configure the data-parallel and tensor-parallel sizes:

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import ModelParallelStrategy

    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize_feedforward,
        # Define the size of the 2D parallelism
        # Set these to "auto" (default) to apply TP intra-node and FSDP inter-node
        data_parallel_size=2,
        tensor_parallel_size=2,
    )

    fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()


In this example with 4 GPUs, Fabric creates a device mesh that groups GPUs 0-1 and GPUs 2-3 (2 groups because ``data_parallel_size=2``, and 2 GPUs per group because ``tensor_parallel_size=2``).
Later, when ``fabric.setup(model)`` is called, each layer wrapped with FSDP (``fully_shard``) is split into two shards, one for the GPU 0-1 group and one for the GPU 2-3 group.
Finally, tensor parallelism is applied within each group, splitting the sharded tensors across the GPUs of that group.

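The mesh described above can be pictured with PyTorch's :func:`~torch.distributed.device_mesh.init_device_mesh`. This is only an illustrative sketch of the grouping, assuming the dimension names used in ``parallelize_feedforward``; Fabric builds and names the mesh for you during launch, so you normally never write this yourself.

.. code-block:: python

    from torch.distributed.device_mesh import init_device_mesh

    # 2 data-parallel groups x 2 tensor-parallel GPUs per group -> mesh [[0, 1], [2, 3]]
    # Sketch only: run on each rank inside an initialized distributed environment
    device_mesh = init_device_mesh(
        "cuda",
        (2, 2),
        mesh_dim_names=("data_parallel", "tensor_parallel"),
    )

    tp_mesh = device_mesh["tensor_parallel"]  # TP groups: {GPU 0, GPU 1} and {GPU 2, GPU 3}
    dp_mesh = device_mesh["data_parallel"]    # DP groups: {GPU 0, GPU 2} and {GPU 1, GPU 3}
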
.. collapse:: Full training example (requires at least 4 GPUs).

    .. code-block:: python

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
        from torch.distributed.tensor.parallel import parallelize_module
        from torch.distributed._composable.fsdp.fully_shard import fully_shard

        import lightning as L
        from lightning.pytorch.demos.boring_classes import RandomDataset
        from lightning.fabric.strategies import ModelParallelStrategy


        class FeedForward(nn.Module):
            def __init__(self, dim, hidden_dim):
                super().__init__()
                self.w1 = nn.Linear(dim, hidden_dim, bias=False)
                self.w2 = nn.Linear(hidden_dim, dim, bias=False)
                self.w3 = nn.Linear(dim, hidden_dim, bias=False)

            def forward(self, x):
                return self.w2(F.silu(self.w1(x)) * self.w3(x))


        def parallelize_feedforward(model, device_mesh):
            # Lightning will set up a device mesh for you
            # Here, it is 2-dimensional
            tp_mesh = device_mesh["tensor_parallel"]
            dp_mesh = device_mesh["data_parallel"]

            if tp_mesh.size() > 1:
                # Use PyTorch's distributed tensor APIs to parallelize the model
                plan = {
                    "w1": ColwiseParallel(),
                    "w2": RowwiseParallel(),
                    "w3": ColwiseParallel(),
                }
                parallelize_module(model, tp_mesh, plan)

            if dp_mesh.size() > 1:
                # Use PyTorch's FSDP2 APIs to parallelize the model
                fully_shard(model.w1, mesh=dp_mesh)
                fully_shard(model.w2, mesh=dp_mesh)
                fully_shard(model.w3, mesh=dp_mesh)
                fully_shard(model, mesh=dp_mesh)

            return model


        strategy = ModelParallelStrategy(
            parallelize_fn=parallelize_feedforward,
            data_parallel_size=2,
            tensor_parallel_size=2,
        )

        fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
        fabric.launch()

        # Initialize the model
        model = FeedForward(8192, 8192)
        model = fabric.setup(model)

        # Define the optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=True)
        optimizer = fabric.setup_optimizers(optimizer)

        # Define dataset/dataloader
        dataset = RandomDataset(8192, 128)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        dataloader = fabric.setup_dataloaders(dataloader)

        # Simplified training loop
        for i, batch in enumerate(dataloader):
            output = model(batch)
            loss = output.sum()
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            fabric.print(f"Iteration {i} complete")

        fabric.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

|

Beyond this toy example, we recommend you study our `LLM 2D Parallel Example (Llama 3) <https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel>`_.


----


*******************
Effective use cases
*******************

In the toy example above, the parallelization is configured to work within a single machine across multiple GPUs.
In practice, however, the main use case for 2D parallelism is multi-node training, where the two methods can be combined effectively to maximize throughput and model scale.
Since tensor parallelism requires blocking collective calls, fast GPU-to-GPU data transfers are essential to keep throughput high, so TP is typically applied across the GPUs within a machine.
FSDP, on the other hand, can by design overlap GPU transfers with computation (it can prefetch layers).
Hence, combining FSDP for inter-node parallelism with TP for intra-node parallelism is generally a good strategy to minimize both latency and network bandwidth usage, making it possible to scale to much larger models than with FSDP alone.


.. code-block:: python

    from lightning.fabric.strategies import ModelParallelStrategy

    strategy = ModelParallelStrategy(
        # Default is "auto"
        # Applies TP intra-node and DP inter-node
        data_parallel_size="auto",
        tensor_parallel_size="auto",
    )

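To make that mapping concrete, here is a rough sketch (not part of the commit) of the equivalent explicit configuration for a cluster with 2 nodes of 8 GPUs each, reusing the ``parallelize_feedforward`` function from the example above:

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import ModelParallelStrategy

    # FSDP across the 2 nodes, TP across the 8 GPUs inside each node
    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize_feedforward,
        data_parallel_size=2,     # one shard group per node
        tensor_parallel_size=8,   # split tensors across the GPUs of a node
    )

    fabric = L.Fabric(accelerator="cuda", devices=8, num_nodes=2, strategy=strategy)
    fabric.launch()
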
----


***************************
Data-loading considerations
***************************

In a tensor-parallelized model, it is important that every GPU participating in the same tensor-parallel group receives identical input data.
Across the data-parallel dimension, however, the inputs should differ.
In other words, if TP is applied within a node and FSDP across nodes, each node must receive a different batch, but every GPU within a node gets the same batch of data.

If you use a PyTorch data loader and set it up with :meth:`~lightning.fabric.fabric.Fabric.setup_dataloaders`, Fabric handles this automatically for you by configuring the distributed sampler.
However, if you shuffle data in your dataset or data loader, or apply randomized transformations/augmentations to your data, you must still ensure that the seed is set appropriately.


.. code-block:: python

    import lightning as L
    from torch.utils.data import DataLoader

    fabric = L.Fabric(...)

    # Define dataset/dataloader
    # If there is randomness/augmentation in the dataset, fix the seed
    dataset = MyDataset(seed=42)
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Fabric configures the sampler automatically for you such that
    # all batches in a tensor-parallel group are identical,
    # while still sharding the dataset across the data-parallel group
    dataloader = fabric.setup_dataloaders(dataloader)

    for i, batch in enumerate(dataloader):
        ...

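If your dataset or transforms use randomness, one simple way to keep the seed consistent across processes is Lightning's ``seed_everything`` helper; this is a minimal sketch, and you should adapt it to your own augmentation pipeline.

.. code-block:: python

    from lightning.fabric import seed_everything

    # Seed Python, NumPy and torch RNGs identically on every process so that
    # shuffling and augmentations stay in sync within each tensor-parallel group
    seed_everything(42)
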
----


**********
Next steps
**********

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: LLM 2D Parallel Example
   :description: Full example showing how to combine TP + FSDP in a large language model (Llama 3)
   :col_css: col-md-4
   :button_link: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/fabric/tensor_parallel
   :height: 160
   :tag: advanced

.. displayitem::
   :header: Pipeline Parallelism
   :description: Coming soon
   :col_css: col-md-4
   :height: 160
   :tag: advanced


.. raw:: html

        </div>
    </div>

|
