import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate the required minimum of {_min_gpu_count} GPUs to run this example. Exiting.")
    sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh

| 29 | +""" |
| 30 | +This is the script to test 2D Parallel which combines Tensor/Sequence |
| 31 | +parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model |
| 32 | +in the SPMD style. We show an E2E working flow from forward, backward |
| 33 | +and optimization. |
| 34 | +
|
| 35 | +We enabled Fully Sharded Data Parallel + Tensor Parallel in |
| 36 | +separate parallel dimensions: |
| 37 | + Data Parallel ("dp") across hosts |
| 38 | + Tensor Parallel ("tp") within each host |
| 39 | +
|
| 40 | + We use a simple diagram to illustrate below: |
| 41 | +
|
| 42 | +====================================================================== |
| 43 | +------------ ------------ ------------ ------------ |
| 44 | +| Host 1 | | Host 2 | | | | Host N | |
| 45 | +| 8 GPUs | | 8 GPUs | | | | 8 GPUs | |
| 46 | +| | | | | ... | | | |
| 47 | +| (TP) | | (TP) | | | | (TP) | |
| 48 | +|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7| |
| 49 | +| | | | | | | .., 8N-1]| |
| 50 | +| | | | | | | | |
| 51 | +------------ ------------ ------------ ------------ |
| 52 | +FSDP: |
| 53 | +[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1] |
| 54 | +====================================================================== |
| 55 | +
|
| 56 | +More details can be seen in the slide: |
| 57 | +https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ |
| 58 | +""" |


def find_multiple(n: int, k: int) -> int:
    """Round n up to the nearest multiple of k (used to size the SwiGLU hidden dim)."""
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP_swiglu(nn.Module):
    """SwiGLU MLP block to showcase a Llama-style MLP model."""

    def __init__(self, mlp_dim: int = 1024) -> None:
        super().__init__()
        hidden_dim = 4 * mlp_dim
        scaled_hidden = int(2 * hidden_dim / 3)
        rounded_hidden = find_multiple(scaled_hidden, 256)
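        # With the default mlp_dim=1024 this works out to hidden_dim=4096,
        # scaled_hidden=2730, and rounded_hidden=2816 (the next multiple of 256).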

        self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

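    # SwiGLU feed-forward: out_proj(silu(in_proj(x)) * gate_proj(x)),
    # i.e. a SiLU-gated linear unit followed by a down projection.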
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.in_proj(x)) * self.gate_proj(x)
        x = self.out_proj(x)
        return x


| 87 | +""" |
| 88 | +Main body of the demo of a basic version of tensor parallel by using |
| 89 | +PyTorch native APIs. |
| 90 | +""" |
| 91 | +tp_size = 2 |
| 92 | +logger = get_logger() |
| 93 | + |
| 94 | +# understand world topology |
| 95 | +_rank = int(os.environ["RANK"]) |
| 96 | +_world_size = int(os.environ["WORLD_SIZE"]) |
| 97 | + |
| 98 | + |
| 99 | +print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.") |
| 100 | +assert ( |
| 101 | + _world_size % tp_size == 0 |
| 102 | +), f"World size {_world_size} needs to be divisible by TP size {tp_size}" |


# Create a sharding plan based on the given world_size.
dp_size = _world_size // tp_size

# Create a device mesh with 2 dimensions.
# First dim is the data parallel dimension,
# second dim is the tensor parallel dimension.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
tp_mesh = device_mesh["tp"]
dp_mesh = device_mesh["dp"]

# To provide identical inputs within each TP group, we need the dp process group.
dp_pg = device_mesh.get_dim_groups()[0]

# For TP, the input needs to be the same across all TP ranks,
# while for SP, the input can differ across ranks.
# We use dp_rank to set the random seed
# to mimic the behavior of a dataloader.
dp_rank = dist.get_rank(dp_pg)


# Create the model and move it to the GPU for this rank.
_mlp_dim = 1024
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")


# Custom parallelization plan for the SwiGLU MLP model.
custom_tp_model = parallelize_module(
    module=base_model_swiglu,
    device_mesh=tp_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "gate_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
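
# Note on the plan above: ColwiseParallel shards in_proj and gate_proj along
# their output dimension, so each TP rank computes a slice of the hidden
# activations; RowwiseParallel shards out_proj along its input dimension and
# all-reduces the partial results. With the default replicated inputs and
# outputs, this Megatron-style split needs only one all-reduce per forward
# pass through the block.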

rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")

# Init FSDP using the dp device mesh.
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)
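# FSDP now shards the already TP-sharded parameters a second time, across the
# "dp" mesh dimension, producing the 2D layout sketched at the top of the file.
# use_orig_params=True keeps the original (DTensor) parameters visible to the
# optimizer, which this TP + FSDP composition relies on.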

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}")
optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)

# Training loop:
# Perform a number of iterations of forward/backward
# and optimizer steps for the sharded module.
rank_log(_rank, logger, "\nStarting 2D training...")
num_iterations = 10
batch_size = 2

for i in range(num_iterations):
    # Seeding with dp_rank ensures identical inputs within each TP group.
    torch.manual_seed(i + dp_rank)
    inp = torch.rand(batch_size, _mlp_dim, device="cuda")

    optimizer.zero_grad()  # clear gradients from the previous iteration
    output = sharded_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_log(_rank, logger, f"2D iter {i} complete")

rank_log(_rank, logger, "2D training successfully completed!")
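
# This script reads RANK / WORLD_SIZE from the environment, so it is meant to
# be launched with torchrun, for example (assuming the file is saved as
# fsdp_tp_example.py):
#   torchrun --nnodes=1 --nproc_per_node=4 fsdp_tp_example.py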