#!/usr/bin/python3
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# See LICENSE for license information.
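"""Toy distributed-training example for debugging FSDP2 (fully_shard) with
Transformer Engine FP8 layers.

A small LayerNormLinear -> LayerNormMLP -> Linear stack is trained under DDP
or FSDP2/HSDP with a selectable FP8 recipe, optional FP8 weight
initialization, PyTorch profiling, CUDA memory snapshots, and per-iteration
gradient dumps (e.g. for cross-run comparison).

Example launch (the file name is illustrative; the script must run under
torchrun) on a single 8-GPU node with HSDP over a 2x4 mesh:

    torchrun --nproc_per_node=8 fsdp2_example.py --use-fsdp2 \
        --sharding-dims 2 4 --recipe delayed
"""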

import os
import sys
import argparse
import pickle
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    MXFP8BlockScaling,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import fp8_model_init

class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, use_fsdp2=False):
        super().__init__()

        # LayerNormLinear: fuses LayerNorm + Linear
        self.ln_linear = te.LayerNormLinear(
            in_features=input_size,
            out_features=hidden_size,
            eps=1e-5,
            use_fsdp2=use_fsdp2,
            keep_fp8_weight_transpose_cache=False,
        )

        # LayerNormMLP: fuses LayerNorm + FC1 + activation + FC2
        self.ln_mlp = te.LayerNormMLP(
            hidden_size=hidden_size,
            ffn_hidden_size=hidden_size * 4,  # typical 4x expansion
            use_fsdp2=use_fsdp2,
            keep_fp8_weight_transpose_cache=False,
        )

        # Regular Linear for the final projection
        self.fc_out = te.Linear(
            hidden_size,
            output_size,
            use_fsdp2=use_fsdp2,
            keep_fp8_weight_transpose_cache=False,
        )

    def forward(self, x):
        # LayerNormLinear: applies LayerNorm then Linear
        x = self.ln_linear(x)

        # LayerNormMLP: applies LayerNorm + FC1 + GELU + FC2
        x = self.ln_mlp(x)

        # Final Linear projection
        x = self.fc_out(x)

        return x

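# fully_shard() replaces each module parameter with a DTensor, which drops the
# extra Python-level attributes that Transformer Engine attaches to its
# parameters. The two helpers below snapshot those attributes before sharding
# and re-apply them afterwards; _SKIP_KEYS filters out tensor-backed internals
# that must not be copied onto the sharded parameters (inferred from how the
# helpers are used in _train() below).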
def save_custom_attrs(module, _SKIP_KEYS=frozenset({"_data", "_module", "_transpose"})):
    custom_attrs = {}
    for name, param in module.named_parameters():
        custom_attrs[name] = dict(vars(param))
        for k in _SKIP_KEYS:
            custom_attrs[name].pop(k, None)
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for training")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of training iterations"
    )
    parser.add_argument("--profile", action="store_true",
                        help="Enable PyTorch profiling.")
    parser.add_argument("--profile-step-start", type=int, default=6,
                        help="Global step to start profiling.")
    parser.add_argument("--profile-step-end", type=int, default=7,
                        help="Global step to stop profiling.")
    parser.add_argument("--profile-ranks", nargs="+", type=int, default=[0],
                        help="Global ranks to profile.")
    parser.add_argument("--tensorboard-dir", type=str, default="./fsdp2_tensorboard",
                        help="Write TensorBoard logs to this directory.")
    parser.add_argument("--gradients-save-file", type=str, default="all_iters.pt",
                        help="Write the gradients from all iterations to this file.")
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    parser.add_argument("--use-fsdp2", action="store_true",
                        help="Enable FSDP2 (fully_shard) training instead of DDP.")
    parser.add_argument("--memory-profile", action="store_true",
                        help="Record CUDA memory history and dump a snapshot.")
    parser.add_argument(
        "--recipe",
        type=str,
        choices=["delayed", "mxfp8", "current"],
        default="delayed",
        help="Select the FP8 recipe to use: 'delayed', 'mxfp8', or 'current'.",
    )

    # FSDP/HSDP sharding dimensions as a space-separated list: one value
    # (the world size) for FSDP, two values (replicate, shard) for HSDP.
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
    )
    args = parser.parse_args(argv, namespace)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 2, "--sharding-dims accepts at most two values"
    return args


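# Transformer Engine module classes that each get wrapped in their own FSDP
# unit in _train() below.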
sub_modules_to_wrap = [te.Linear, te.LayerNormLinear, te.LayerNormMLP]


def _train(args):
    assert "TORCHELASTIC_RUN_ID" in os.environ, "this script must be launched with torchrun"
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert LOCAL_SIZE == WORLD_SIZE
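    # Single-node assumption: with LOCAL_SIZE == WORLD_SIZE, global and local
    # ranks coincide, so set_device(WORLD_RANK) below is equivalent to using
    # LOCAL_RANK.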

    # Set device and initialize RNG states (identical seeds on every rank keep
    # the generated inputs and targets in sync across ranks)
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize the torch.distributed global process group
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    device = torch.device(f"cuda:{LOCAL_RANK}")

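    # Recipe summary (Transformer Engine):
    #   - DelayedScaling: FP8 scales derived from a history of past amax values.
    #   - Float8CurrentScaling: per-tensor scales computed from the current tensor.
    #   - MXFP8BlockScaling: MXFP8 format with fine-grained per-block scaling.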
    # FP8 configuration
    if args.recipe == "current":
        fp8_recipe = Float8CurrentScaling()
    elif args.recipe == "mxfp8":
        fp8_recipe = MXFP8BlockScaling()
    elif args.recipe == "delayed":
        fp8_recipe = DelayedScaling()
    else:
        raise ValueError(f"Unsupported recipe: {args.recipe}")

    if args.memory_profile:
        torch.cuda.memory._record_memory_history(enabled="all", context="all", stacks="all")
    if args.fp8_init:
        # Build the model with primary weights kept in FP8
        with fp8_model_init(enabled=True):
            model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
    else:
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
    if not args.memory_profile:
        # Load a pre-saved checkpoint (presumably so different runs start from
        # identical weights); the file must exist beforehand.
        model.load_state_dict(torch.load("fsdp_model.pth"))
    # Move the model to the correct device
    model.to(device)

    # Create a DeviceMesh for fully_shard
    device_ids = list(range(WORLD_SIZE))

    # Apply FSDP/HSDP
    if args.use_fsdp2:
        custom_attrs = save_custom_attrs(model)
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
            print(f"sharding-dims:{args.sharding_dims}")
        # Set up the sharding mesh for FSDP/HSDP
        if args.sharding_dims is None:  # FSDP
            mesh = DeviceMesh("cuda", device_ids)
        elif len(args.sharding_dims) == 1:  # FSDP with an explicit dim
            assert args.sharding_dims[0] == WORLD_SIZE
            mesh = DeviceMesh("cuda", device_ids)
        elif len(args.sharding_dims) == 2:  # HSDP
            assert args.sharding_dims[0] * args.sharding_dims[1] == WORLD_SIZE
            mesh = init_device_mesh(
                "cuda",
                (args.sharding_dims[0], args.sharding_dims[1]),
                mesh_dim_names=("replicate", "shard"),
            )
        else:
            assert False, "--sharding-dims accepts at most two values"
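        # Wrap each TE submodule as its own FSDP unit so parameters are
        # gathered and resharded per layer, then wrap the root module to pick
        # up any remaining parameters.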
        for sub_module in model.modules():
            if isinstance(sub_module, tuple(sub_modules_to_wrap)):
                fully_shard(sub_module, mesh=mesh)
        fully_shard(model, mesh=mesh, reshard_after_forward=True)
        restore_custom_attrs(model, custom_attrs)
    else:
        model = DDP(model, device_ids=[LOCAL_RANK])

    optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)

    input_path = Path("shared_input.pt")
    if input_path.exists():
        input_data = torch.load(input_path).to(device)
    else:
        input_data = torch.randn(args.batch_size, args.input_size, requires_grad=True).to(device)
        torch.save(input_data.cpu(), input_path)
        print("Generated and saved shared input tensor.")
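    # Because every rank seeds identically, ranks that generate the tensor
    # concurrently produce (and write) the same values; saving it lets later
    # runs with a different parallelism mode reuse the exact same input.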

    out_tensors = []
    prof = None
    if (
        args.profile
        and torch.distributed.get_rank() in args.profile_ranks
    ):
        prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(
                wait=max(args.profile_step_start - 1, 0),
                warmup=1 if args.profile_step_start > 0 else 0,
                active=args.profile_step_end - args.profile_step_start,
                repeat=1,
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        )
        prof.start()
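    # With the schedule above, the profiler waits profile_step_start - 1
    # steps, warms up for one step, then records profile_step_end -
    # profile_step_start steps (defaults: wait 5, warm up 1, record 1);
    # prof.step() in the loop below advances the schedule once per iteration.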
    for iteration in range(args.iter):
        if LOCAL_RANK == 0:
            print(f"Starting iteration {iteration}...")
        if prof is not None:
            prof.step()

        # Zero the parameter gradients
        optimizer.zero_grad()
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            output = model(input_data)
            target = torch.randn(args.batch_size, args.output_size).to(device)
            loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
        if not args.profile and not args.memory_profile:
            with torch.no_grad():
                for name, p in model.named_parameters():
                    full_grad = None
                    if p.grad is not None and hasattr(p.grad, "full_tensor"):
                        # 1. Materialize the full (unsharded) gradient. For
                        # DTensor gradients this is a collective, so it must
                        # run on ALL ranks.
                        full_grad = p.grad.full_tensor().detach().clone()
                    elif p.grad is not None:
                        full_grad = p.grad.detach().clone()
                    # 2. Only rank 0 stores the result
                    if LOCAL_RANK == 0 and p.requires_grad:
                        out_tensors.append((name, full_grad))
        if (
            prof is not None
            and iteration == args.profile_step_end
        ):
            prof.stop()

    if (not args.profile and not args.memory_profile) and LOCAL_RANK == 0:
        torch.save(out_tensors, args.gradients_save_file)

    if args.memory_profile:
        snapshot = torch.cuda.memory._snapshot()
        with open("memory_snapshot.pickle", "wb") as f:
            pickle.dump(snapshot, f)
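        # The snapshot can be inspected with PyTorch's memory_viz tool
        # (https://pytorch.org/memory_viz).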
        # Disable memory history recording now that it is no longer needed
        torch.cuda.memory._record_memory_history(enabled=None)

    # NOTE: In PyTorch < 2.6 there is a teardown race where one rank may call
    # destroy_process_group() while other ranks still have in-flight NCCL ops,
    # which can trigger an NCCL/RCCL comm error. Newer releases (>= 2.6) fixed
    # this, but we keep a version-guarded barrier on older Torch for stability.
    if torch_version() < (2, 6, 0):
        dist.barrier(device_ids=[torch.cuda.current_device()])
    dist.destroy_process_group()

    return 0


if __name__ == "__main__":
    sys.exit(_train(_parse_args()))