|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | +from torch.utils.data import DataLoader |
| 5 | +from torch.distributed._composable.fsdp.fully_shard import fully_shard |
| 6 | +from torch.distributed.device_mesh import DeviceMesh |
| 7 | + |
| 8 | +from torchao.float8 import convert_to_float8_training, Float8LinearConfig |
| 9 | + |
| 10 | +import lightning as L |
| 11 | +from lightning.fabric.strategies import ModelParallelStrategy |
| 12 | +from lightning.pytorch.demos import Transformer, WikiText2 |
| 13 | + |
| 14 | +from tqdm import tqdm |
| 15 | + |
| 16 | + |
| 17 | +def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 18 | + float8_config = Float8LinearConfig( |
| 19 | + # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly |
| 20 | + pad_inner_dim=True, |
| 21 | + ) |
| 22 | + |
| 23 | + def module_filter_fn(mod: torch.nn.Module, fqn: str): |
| 24 | + # we skip the decoder because it typically vocabulary size |
| 25 | + # is not divisible by 16 as required by float8 |
| 26 | + if fqn == "decoder": |
| 27 | + return False |
| 28 | + return True |
| 29 | + |
| 30 | + convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn) |
| 31 | + |
| 32 | + for module in model.modules(): |
| 33 | + if isinstance(module, (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)): |
| 34 | + fully_shard(module, mesh=device_mesh) |
| 35 | + |
| 36 | + fully_shard(model, mesh=device_mesh) |
| 37 | + |
| 38 | + model = torch.compile(model) |
| 39 | + |
| 40 | + return model |
| 41 | + |
| 42 | + |
| 43 | +def train(): |
| 44 | + L.seed_everything(42) |
| 45 | + |
| 46 | + batch_size = 8 |
| 47 | + micro_batch_size = 1 |
| 48 | + |
| 49 | + dataset = WikiText2() |
| 50 | + dataloader = DataLoader(dataset, num_workers=8, batch_size=micro_batch_size) |
| 51 | + |
| 52 | + with torch.device("meta"): |
| 53 | + model = Transformer( |
| 54 | + vocab_size=dataset.vocab_size, |
| 55 | + nlayers=16, |
| 56 | + nhid=4096, |
| 57 | + ninp=1024, |
| 58 | + nhead=32, |
| 59 | + ) |
| 60 | + |
| 61 | + strategy = ModelParallelStrategy( |
| 62 | + data_parallel_size=4, |
| 63 | + tensor_parallel_size=1, |
| 64 | + parallelize_fn=configure_model |
| 65 | + ) |
| 66 | + |
| 67 | + fabric = L.Fabric(precision="bf16-true", strategy=strategy) |
| 68 | + fabric.launch() |
| 69 | + |
| 70 | + model = fabric.setup(model) |
| 71 | + |
| 72 | + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) |
| 73 | + optimizer = fabric.setup_optimizers(optimizer) |
| 74 | + |
| 75 | + dataloader = fabric.setup_dataloaders(dataloader) |
| 76 | + |
| 77 | + iterable = tqdm(enumerate(dataloader), total=len(dataloader)) if fabric.is_global_zero else enumerate(dataloader) |
| 78 | + |
| 79 | + for i, batch in iterable: |
| 80 | + input, target = batch |
| 81 | + |
| 82 | + is_accumulating = i % (batch_size // micro_batch_size) != 0 |
| 83 | + |
| 84 | + with fabric.no_backward_sync(model, enabled=is_accumulating): |
| 85 | + output = model(input, target) |
| 86 | + loss = F.nll_loss(output, target.view(-1)) |
| 87 | + fabric.backward(loss) |
| 88 | + |
| 89 | + if not is_accumulating: |
| 90 | + fabric.clip_gradients(model, optimizer, max_norm=1.0) |
| 91 | + optimizer.step() |
| 92 | + optimizer.zero_grad() |
| 93 | + |
| 94 | + if fabric.is_global_zero: |
| 95 | + iterable.set_postfix_str(f"train_loss={loss.item():.2f}") |
| 96 | + |
| 97 | + if i // (batch_size // micro_batch_size) > 100: |
| 98 | + break |
| 99 | + |
| 100 | + fabric.print(torch.cuda.memory_summary()) |
| 101 | + |
| 102 | + |
| 103 | +if __name__ == "__main__": |
| 104 | + torch.set_float32_matmul_precision('high') |
| 105 | + |
| 106 | + train() |
0 commit comments