
Commit 746c0a2

Modernize distributed/rpc/pipeline (#1385)
* Modernize RPC example
* update
* update
1 parent: 7fce8bb · commit: 746c0a2

2 files changed (+43, −8 lines)


distributed/rpc/pipeline/main.py

Lines changed: 41 additions & 6 deletions
@@ -1,6 +1,7 @@
 import os
 import threading
 import time
+import warnings
 from functools import wraps
 
 import torch
@@ -9,11 +10,18 @@
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
-from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
 
 from torchvision.models.resnet import Bottleneck
 
+# Suppress warnings that can't be fixed from user code
+warnings.filterwarnings("ignore",
+                        message="You are using a Backend .* as a ProcessGroup. This usage is deprecated",
+                        category=UserWarning)
+warnings.filterwarnings("ignore",
+                        message="networkx backend defined more than once: nx-loopback",
+                        category=RuntimeWarning)
+
 
 #########################################################
 #           Define Model Parallel ResNet50              #
@@ -185,15 +193,34 @@ def parameter_rrefs(self):
 image_h = 128
 
 
+def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
+    """Create torch.compiled optimizers on each worker"""
+    params = [p.to_here() for p in param_rrefs]
+    opt = optim.SGD(params, lr=lr)
+    opt.step = torch.compile(opt.step)
+    return opt
+
+
 def run_master(split_size):
 
     # put the two model parts on worker1 and worker2 respectively
     model = DistResNet50(split_size, ["worker1", "worker2"])
     loss_fn = nn.MSELoss()
-    opt = DistributedOptimizer(
-        optim.SGD,
-        model.parameter_rrefs(),
-        lr=0.05,
+
+    # Get parameter RRefs for each model shard
+    p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
+    p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()
+
+    # Create optimizers on remote workers
+    opt1_rref = rpc.remote(
+        "worker1",
+        create_optimizer_for_remote_params,
+        args=("worker1", p1_param_rrefs)
+    )
+    opt2_rref = rpc.remote(
+        "worker2",
+        create_optimizer_for_remote_params,
+        args=("worker2", p2_param_rrefs)
     )
 
     one_hot_indices = torch.LongTensor(batch_size) \
@@ -213,7 +240,12 @@ def run_master(split_size):
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
-            opt.step(context_id)
+
+            opt1_rref.rpc_sync().step()
+            opt2_rref.rpc_sync().step()
+
+            opt1_rref.rpc_sync().zero_grad()
+            opt2_rref.rpc_sync().zero_grad()
 
 
 def run_worker(rank, world_size, num_split):
@@ -245,6 +277,9 @@ def run_worker(rank, world_size, num_split):
 
 
 if __name__=="__main__":
+    # Suppress torch compile profiler warnings
+    os.environ['TORCH_LOGS'] = '-dynamo'
+
    world_size = 3
    for num_split in [1, 2, 4, 8]:
        tik = time.time()
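
For reference, below is a minimal, self-contained sketch (not part of this commit) of the pattern the diff adopts: the optimizer is constructed with rpc.remote() on the worker that owns the parameters, and the master only keeps an RRef to it and drives step()/zero_grad() through the RRef proxy. The Shard module, the compute_loss helper, the "master"/"worker1" names, and the two-process setup are illustrative assumptions; for brevity the sketch uses plain local autograd on the owning worker instead of the example's distributed autograd context, and it omits the torch.compile wrapping.

import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.rpc import RRef


class Shard(nn.Module):
    """Toy stand-in for one model shard (the real example shards ResNet50)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

    def parameter_rrefs(self):
        # Same idea as the example: hand out RRefs to locally owned parameters.
        return [RRef(p) for p in self.parameters()]


def create_optimizer_for_remote_params(param_rrefs, lr=0.05):
    # Runs on the owning worker, so to_here() resolves to the local tensors.
    params = [p.to_here() for p in param_rrefs]
    return optim.SGD(params, lr=lr)


def compute_loss(shard_rref, x, target):
    # Runs on the owning worker: a plain forward/backward fills param.grad there.
    out = shard_rref.local_value()(x)
    loss = nn.functional.mse_loss(out, target)
    loss.backward()
    return loss.item()


def run_master():
    shard_rref = rpc.remote("worker1", Shard)
    param_rrefs = shard_rref.remote().parameter_rrefs().to_here()

    # The optimizer lives on worker1; the master only keeps an RRef to it.
    opt_rref = rpc.remote("worker1", create_optimizer_for_remote_params,
                          args=(param_rrefs,))

    loss = rpc.rpc_sync("worker1", compute_loss,
                        args=(shard_rref, torch.randn(4, 8), torch.randn(4, 8)))
    opt_rref.rpc_sync().step()       # apply the worker-local gradients
    opt_rref.rpc_sync().zero_grad()  # clear them for the next iteration
    print(f"loss: {loss:.4f}")


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    name = "master" if rank == 0 else f"worker{rank}"
    rpc.init_rpc(name, rank=rank, world_size=world_size)
    if rank == 0:
        run_master()
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)

Keeping each optimizer on the worker that owns its parameters means no optimizer state or gradients cross the wire; the master only issues lightweight RPC calls through the RRef proxy.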
Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-torch==1.9.0
-torchvision==0.7.0
+torch
+torchvision
