-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_fsdp_toy.py
More file actions
84 lines (69 loc) · 2.42 KB
/
train_fsdp_toy.py
File metadata and controls
84 lines (69 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, TensorDataset
# --- CLI arguments and distributed setup ------------------------------------
# Hyperparameters come from the command line; the process group is initialised
# from the environment (a launcher such as torchrun sets LOCAL_RANK).
parser = argparse.ArgumentParser()
for flag, kind, default in (
    ("--epochs", int, 1),
    ("--batch_size", int, 8),
    ("--lr", float, 1e-3),
    ("--backend", str, "nccl"),
):
    parser.add_argument(flag, type=kind, default=default)
args = parser.parse_args()

# Join the process group, then pin this process to its own GPU.
dist.init_process_group(backend=args.backend)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class TestModel(nn.Module):
    """Minimal MLP classifier: 10 input features -> 2 output logits."""

    def __init__(self):
        super().__init__()
        # One hidden layer of width 64 with a ReLU in between.
        input_layer = nn.Linear(10, 64)
        activation = nn.ReLU()
        output_layer = nn.Linear(64, 2)
        self.net = nn.Sequential(input_layer, activation, output_layer)

    def forward(self, x):
        # x: (batch, 10) -> logits: (batch, 2)
        return self.net(x)
# --- Model, synthetic data, and optimiser ------------------------------------
# Wrap the toy model in FSDP so its parameters are sharded across ranks.
model = TestModel().to(device)
model = FSDP(model)

# Synthetic classification data: 64 samples, 10 features, binary labels.
# NOTE(review): each rank generates its own random data (no fixed seed, no
# DistributedSampler), so ranks see different samples — acceptable for a toy.
x = torch.randn(64, 10).to(device)
y = torch.randint(0, 2, (64,)).to(device)
dataset = TensorDataset(x, y)
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()
# Timing and peak-memory tracking are initialised per epoch inside the
# training loop, so the former pre-loop reset here was dead code and is gone.
# --- Training loop -----------------------------------------------------------
# Rank is fixed after init_process_group, so look it up once instead of once
# per logging check inside the inner loop.
rank = dist.get_rank()
for epoch in range(args.epochs):
    model.train()
    total_loss = 0.0
    # Reset timing and peak-memory tracking at the start of each epoch so the
    # reported metrics cover this epoch only.
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    for step, (inputs, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # FSDP handles the cross-rank gradient communication
        optimizer.step()
        total_loss += loss.item()
        # Log every 5th step from rank 0 only, to avoid duplicated output.
        if step % 5 == 0 and rank == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

    # Epoch metrics: mean per-step loss, wall time, throughput, peak GPU mem.
    avg_loss = total_loss / len(loader)
    elapsed = time.time() - start_time
    throughput = len(loader.dataset) / elapsed
    peak_mem = torch.cuda.max_memory_allocated() / 1e6  # bytes -> MB
    if rank == 0:
        print(
            f"Epoch {epoch} finished. "
            f"Avg Loss: {avg_loss:.4f}, "
            f"Time: {elapsed:.2f}s, "
            f"Throughput: {throughput:.2f} samples/sec, "
            f"Peak Mem: {peak_mem:.2f} MB"
        )

# Tear down the process group so every rank exits cleanly.
if dist.is_initialized():
    dist.destroy_process_group()