-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_ddp.py
More file actions
66 lines (50 loc) · 1.78 KB
/
test_ddp.py
File metadata and controls
66 lines (50 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def spmd_main():
    """Set up the default process group from the environment and run the demo.

    Reads ``MASTER_ADDR``, ``MASTER_PORT``, ``RANK``, ``WORLD_SIZE``,
    ``LOCAL_RANK`` and ``LOCAL_WORLD_SIZE`` from ``os.environ`` (these are
    expected to be exported by the launcher, e.g. ``torchrun``), binds this
    process to the accelerator whose index is ``LOCAL_RANK``, initialises the
    process group with the default backend for the current accelerator, runs
    ``demo_basic`` and finally destroys the process group.

    Raises:
        KeyError: if any of the required environment variables is not set.
        ValueError: if ``LOCAL_RANK`` is not an integer.
    """
    # These are the parameters used to initialize the process group.
    # LOCAL_RANK and LOCAL_WORLD_SIZE have to be collected here too: the
    # local rank is looked up in this dict just below.
    env_dict = {
        key: os.environ[key]
        for key in (
            "MASTER_ADDR",
            "MASTER_PORT",
            "RANK",
            "WORLD_SIZE",
            "LOCAL_RANK",
            "LOCAL_WORLD_SIZE",
        )
    }
    # Device indices are per node, so the device is picked with the local
    # rank; the global RANK only matches a device index on a single node.
    local_rank = int(env_dict["LOCAL_RANK"])

    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    acc = torch.accelerator.current_accelerator()
    vendor_backend = torch.distributed.get_default_backend_for_device(acc)
    torch.accelerator.set_device_index(local_rank)
    torch.distributed.init_process_group(backend=vendor_backend)
    try:
        demo_basic(local_rank)
    finally:
        # Tear down the process group even when the demo raises.
        torch.distributed.destroy_process_group()
def demo_basic(rank):
    """Run one forward/backward/optimizer step of a DDP-wrapped ``ToyModel``.

    Args:
        rank: Device index. The model and the labels are moved to it with
            ``.to(rank)`` and it is passed to DDP as ``device_ids=[rank]``.
    """
    pid = os.getpid()
    group_rank = torch.distributed.get_rank()
    group_size = torch.distributed.get_world_size()
    print(
        f"[{pid}] rank = {group_rank}, "
        + f"world_size = {group_size}"
    )

    # Build the toy network on the target device and wrap it for DDP.
    wrapped = DDP(ToyModel().to(rank), device_ids=[rank])
    criterion = nn.MSELoss()
    sgd = optim.SGD(wrapped.parameters(), lr=0.001)

    sgd.zero_grad()

    # NOTE(review): the input batch is created without an explicit
    # ``.to(rank)``; presumably DDP moves it to ``device_ids[0]`` -- confirm.
    batch = torch.randn(20, 10)
    predictions = wrapped(batch)
    targets = torch.randn(20, 5).to(rank)

    loss = criterion(predictions, targets)
    loss.backward()
    sgd.step()
class ToyModel(nn.Module):
    """Small two-layer network: Linear(10, 10) -> ReLU -> Linear(10, 5)."""

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation consumes the random generator identically.
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        """Apply ``net1``, then ``relu``, then ``net2`` to ``x``."""
        hidden = self.net1(x)
        activated = self.relu(hidden)
        return self.net2(activated)
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    spmd_main()