
Commit c4dc481

[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) (#1201)
* update requirements.txt
* add torchrun support, move to init_device_mesh
* update twod fully working
* ensure proper dp group seeding for synth data
* swiglu model added
* sequential running of custom, auto, seq parallel models
* streamline to 2D TP only for two_d_parallel example
* sequence parallel working...needs init_device_mesh update
* seq parallel now using init_device_mesh
* tp and sp examples all working and updated
* updates from code review
* remove utils.py. Sample models created in example files
* remove originals.py, leftover imports, various updates from code review feedback.
* code linting via ruff
* code formatting via ruff
* move rank_log to utils.py, update example files
* move logging imports and config to log_utils, update examples with new import
* add gpu verification, update run_python_examples.sh
* update min gpu = 4 for fsdp+tp
* move gpu check to top of examples, but before import init_device_mesh to clear CI
1 parent f0d6fc9 commit c4dc481

File tree: 9 files changed (+385, -278 lines)
fsdp_tp_example.py: new file, 170 additions, 0 deletions
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an E2E working flow from forward, backward
and optimization.

We enabled Fully Sharded Data Parallel + Tensor Parallel in
separate parallel dimensions:
    Data Parallel ("dp") across hosts
    Tensor Parallel ("tp") within each host

We use a simple diagram to illustrate below:

======================================================================
------------       ------------       ------------       ------------
|  Host 1  |       |  Host 2  |       |          |       |  Host N  |
|  8 GPUs  |       |  8 GPUs  |       |          |       |  8 GPUs  |
|          |       |          |       |   ...    |       |          |
|   (TP)   |       |   (TP)   |       |          |       |   (TP)   |
|[0,1,..,7]|       |[8,9..,15]|       |          |       |[8N-8,8N-7|
|          |       |          |       |          |       | .., 8N-1]|
|          |       |          |       |          |       |          |
------------       ------------       ------------       ------------
FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================

More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def find_multiple(n: int, k: int) -> int:
    """function to find resizing multiple for SwiGLU MLP"""
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP_swiglu(nn.Module):
    """SwiGLU to showcase a Llama style MLP model"""

    def __init__(self, mlp_dim: int = 1024) -> None:
        super().__init__()
        hidden_dim = 4 * mlp_dim
        scaled_hidden = int(2 * hidden_dim / 3)
        rounded_hidden = find_multiple(scaled_hidden, 256)

        self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.in_proj(x)) * self.gate_proj(x)
        x = self.out_proj(x)
        return x


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
tp_size = 2
logger = get_logger()

# understand world topology
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])


print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.")
assert (
    _world_size % tp_size == 0
), f"World size {_world_size} needs to be divisible by TP size {tp_size}"


# create a sharding plan based on the given world_size.
dp_size = _world_size // tp_size

# Create a device mesh with 2 dimensions.
# First dim is the data parallel dimension
# Second dim is the tensor parallel dimension.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
tp_mesh = device_mesh["tp"]
dp_mesh = device_mesh["dp"]

# To support identical inputs for TP groups, we need the dp process group
dp_pg = device_mesh.get_dim_groups()[0]

# For TP, input needs to be same across all TP ranks.
# while for SP, input can be different across all ranks.
# We will use dp_rank for setting the random seed
# to mimic the behavior of the dataloader.
dp_rank = dist.get_rank(dp_pg)


# create model and move it to GPU with id rank
_mlp_dim = 1024
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")


# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(
    module=base_model_swiglu,
    device_mesh=tp_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "gate_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")

# Init FSDP using the dp device mesh
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}")
optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)

# Training loop:
# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
rank_log(_rank, logger, "\nStarting 2D training...")
num_iterations = 10
batch_size = 2

for i in range(num_iterations):
    # seeding with dp_rank to ensure identical inputs for TP groups
    torch.manual_seed(i + dp_rank)
    inp = torch.rand(batch_size, _mlp_dim, device="cuda")

    output = sharded_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_log(_rank, logger, f"2D iter {i} complete")

rank_log(_rank, logger, "2D training successfully completed!")
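As a quick aside (not part of the committed file): with the 4-GPU minimum above and tp_size = 2, init_device_mesh lays ranks out row-major, so the "tp" and "dp" groups pair up as in the small CPU-only sketch below. The group math here is only an illustration of the mesh the script builds, under the assumption of a 4-rank torchrun launch.

# Illustrative sketch only (assumes WORLD_SIZE=4, tp_size=2); mirrors the
# row-major layout init_device_mesh uses for the ("dp", "tp") mesh above.
import torch

world_size, tp_size = 4, 2
dp_size = world_size // tp_size

mesh = torch.arange(world_size).reshape(dp_size, tp_size)
print(mesh)                                     # tensor([[0, 1], [2, 3]])

tp_groups = [row.tolist() for row in mesh]      # [[0, 1], [2, 3]]: ranks that shard the SwiGLU weights
dp_groups = [col.tolist() for col in mesh.t()]  # [[0, 2], [1, 3]]: ranks FSDP shards parameters across
print(tp_groups, dp_groups)

# Seeding with dp_rank (as in the training loop above) gives ranks {0, 1} one
# input stream and ranks {2, 3} another, i.e. identical inputs within each TP group.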
log_utils.py: new file, 22 additions, 0 deletions
import logging
import torch

logging.basicConfig(
    format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
)

def get_logger():
    return logging.getLogger(__name__)


def rank_log(_rank, logger, msg):
    """helper function to log only on global rank 0"""
    if _rank == 0:
        logger.info(f" {msg}")


def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """ verification that we have at least 2 gpus to run dist examples """
    has_cuda = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    return has_cuda and gpu_count >= min_gpus
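A minimal usage sketch of these helpers, mirroring how the example scripts call them (assumes a torchrun launch that sets RANK and log_utils.py on the import path; not part of the commit):

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count

# Bail out early on machines without enough GPUs, as the examples do.
if not verify_min_gpu_count(min_gpus=2):
    raise SystemExit("Need at least 2 GPUs for the distributed examples.")

logger = get_logger()
_rank = int(os.environ.get("RANK", "0"))
rank_log(_rank, logger, "only global rank 0 emits this log line")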
requirements.txt: 3 additions, 3 deletions
@@ -1,6 +1,6 @@
 # Python dependencies required for running the example
 
 --pre
---extra-index-url https://download.pytorch.org/whl/nightly/cu113
---extra-index-url https://download.pytorch.org/whl/nightly/cu116
-torch >= 1.14.0.dev0; sys_platform == "linux"
+--extra-index-url https://download.pytorch.org/whl/nightly/cu118
+--extra-index-url https://download.pytorch.org/whl/nightly/cu121
+torch >= 2.2.0.dev0; sys_platform == "linux"
run_example.sh: new file, 13 additions, 0 deletions
# To run samples:
# bash run_example.sh {file_to_run.py} {num_gpus}
# where file_to_run = example to launch. Default = 'fsdp_tp_example.py'
# num_gpus = num local gpus to use (must be at least 2). Default = 4

# samples to run include:
# sequence_parallel_example.py
# tensor_parallel_example.py
# fsdp_tp_example.py

echo "Launching ${1:-fsdp_tp_example.py} with ${2:-4} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-fsdp_tp_example.py}
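run_example.sh also lists tensor_parallel_example.py, whose diff is not reproduced above. As a rough sketch of what a plain tensor-parallel plan looks like with the same APIs this commit uses (the ToyModel and the 1D mesh here are assumptions for illustration, not the committed file), launched under torchrun like the other samples:

import os
import torch
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

class ToyModel(nn.Module):
    """Small MLP stand-in, assumed for this sketch."""
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))

# 1D mesh over all ranks started by torchrun.
device_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Shard in_proj column-wise and out_proj row-wise: the activation between the
# two linears stays sharded on its last dim, and RowwiseParallel's default
# replicated output layout all-reduces the final result on every rank.
tp_model = parallelize_module(
    module=ToyModel().to("cuda"),
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

torch.manual_seed(0)                      # same seed on every rank -> identical inputs for pure TP
inp = torch.rand(20, 10, device="cuda")
out = tp_model(inp)
out.sum().backward()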
sequence_parallel_example.py: 87 additions, 61 deletions
@@ -1,19 +1,30 @@
-import argparse
-
+import os
+import sys
 import torch
-import torch.multiprocessing as mp
+import torch.nn as nn
+
+from torch.distributed._tensor import Shard
+
+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+    RowwiseParallel,
+)
+
+from log_utils import rank_log, get_logger, verify_min_gpu_count
+
+
+# ---- GPU check ------------
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+    sys.exit()
+# ---------------------------
 
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed.tensor.parallel import parallelize_module
-from utils import cleanup, setup, ToyModel
 
-try:
-    from torch.distributed.tensor.parallel import (
-        SequenceParallel
-    )
-    SP_AVAILABLE = True
-except BaseException as e:
-    pass
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
 
 
 """
@@ -33,51 +44,66 @@
 """
 
 
-def demo_sp(rank, args):
-    """
-    Main body of the demo of a basic version of sequence parallel by using
-    PyTorch native APIs.
-    """
-    print(f"Running SP example on rank {rank}.")
-    setup(rank, args.world_size)
-
-    # create a sharding plan based on the given world_size.
-    device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))
-
-    # create model and move it to GPU with id rank
-    model = ToyModel().cuda(rank)
-    # Create a optimizer for the parallelized module.
-    LR = 0.25
-    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
-    # Parallelize the module based on the given Parallel Style.
-    model = parallelize_module(model, device_mesh, SequenceParallel())
-
-    # Perform a num of iterations of forward/backward
-    # and optimizations for the sharded module.
-    for _ in range(args.iter_nums):
-        # For SP, input can be different across all ranks.
-        inp = torch.rand(20, 10).cuda(rank)
-        output = model(inp)
-        output.sum().backward()
-        optimizer.step()
-
-    cleanup()
-
-
-if __name__ == "__main__":
-    n_gpus = torch.cuda.device_count()
-    parser = argparse.ArgumentParser()
-    # This is passed in via cmd
-    parser.add_argument("--world_size", type=int, default=n_gpus)
-    parser.add_argument("--iter_nums", type=int, default=10)
-    args = parser.parse_args()
-    # The main entry point is called directly without using subprocess
-    if n_gpus < 2:
-        print("Requires at least 2 GPUs to run.")
-    elif not SP_AVAILABLE:
-        print(
-            "PyTorch doesn't have Sequence Parallelism available,"
-            " need nightly build."
-        )
-    else:
-        mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True)
+class ToyModel(nn.Module):
+    """MLP based model"""
+
+    def __init__(self):
+        super().__init__()
+        self.in_proj = nn.Linear(10, 32)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(32, 5)
+
+    def forward(self, x):
+        return self.out_proj(self.relu(self.in_proj(x)))
+
+
+"""
+Main body of the demo of a basic version of sequence parallel by using
+PyTorch native APIs.
+"""
+logger = get_logger()
+
+# create a device mesh based on the given world_size.
+device_mesh = init_device_mesh(
+    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+)
+
+_rank = device_mesh.get_rank()
+
+print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")
+
+rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
+
+# create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
+model = ToyModel().to("cuda")
+
+# Custom parallelization plan for the model
+sp_model = parallelize_module(
+    module=model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)
+
+
+# Create a optimizer for the parallelized module.
+lr = 0.25
+optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True)
+
+
+# Perform a num of iterations of forward/backward
+# and optimizations for the sharded module.
+num_iters = 10
+rank_log(_rank, logger, "Sequence Parallel training starting...")
+
+for i in range(num_iters):
+    # For SP, input can be different across all ranks.
+    inp = torch.rand(20, 10, device="cuda")
+    output = sp_model(inp)
+    output.sum().backward()
+    optimizer.step()
+    rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")
+
+rank_log(_rank, logger, "Sequence Parallel training completed!")
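The Shard(0) annotations in the new plan are what make this run sequence parallel: each rank feeds its own shard of dim 0, ColwiseParallel(input_layouts=Shard(0)) all-gathers that input before the column-sharded in_proj, and RowwiseParallel(output_layouts=Shard(0)) reduce-scatters out_proj's partial result back to a dim-0 shard. A rough shape walk-through, assuming a 2-rank launch (illustrative only, not part of the commit):

# Shape bookkeeping for the plan above with an assumed world_size of 2.
world_size = 2
local_in = (20, 10)                                  # per-rank input, Shard(0)
global_in = (local_in[0] * world_size, local_in[1])  # (40, 10) after the all-gather
hidden_local = (global_in[0], 32 // world_size)      # (40, 16): column-sharded in_proj output
out_local = (global_in[0] // world_size, 5)          # (20, 5): reduce-scatter back to Shard(0)
print(global_in, hidden_local, out_local)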
