Commit 89e1215

Modernize RPC example
1 parent f1723eb commit 89e1215

2 files changed (+34, -8 lines)


distributed/rpc/pipeline/main.py

Lines changed: 32 additions & 6 deletions
@@ -9,7 +9,6 @@
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
-from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef

 from torchvision.models.resnet import Bottleneck
@@ -185,15 +184,35 @@ def parameter_rrefs(self):
 image_h = 128


+def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
+    """Create optimizer on remote worker for given parameters"""
+    params = [p.to_here() for p in param_rrefs]
+    opt = optim.SGD(params, lr=lr)
+    # Use torch.compile to optimize the optimizer step
+    opt.step = torch.compile(opt.step)
+    return opt
+
+
 def run_master(split_size):

     # put the two model parts on worker1 and worker2 respectively
     model = DistResNet50(split_size, ["worker1", "worker2"])
     loss_fn = nn.MSELoss()
-    opt = DistributedOptimizer(
-        optim.SGD,
-        model.parameter_rrefs(),
-        lr=0.05,
+
+    # Get parameter RRefs for each model shard
+    p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
+    p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()
+
+    # Create optimizers on remote workers
+    opt1_rref = rpc.remote(
+        "worker1",
+        create_optimizer_for_remote_params,
+        args=("worker1", p1_param_rrefs)
+    )
+    opt2_rref = rpc.remote(
+        "worker2",
+        create_optimizer_for_remote_params,
+        args=("worker2", p2_param_rrefs)
     )

     one_hot_indices = torch.LongTensor(batch_size) \
@@ -213,7 +232,14 @@ def run_master(split_size):
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
-            opt.step(context_id)
+
+            # Step both optimizers
+            opt1_rref.rpc_sync().step()
+            opt2_rref.rpc_sync().step()
+
+            # Zero gradients
+            opt1_rref.rpc_sync().zero_grad()
+            opt2_rref.rpc_sync().zero_grad()


 def run_worker(rank, world_size, num_split):
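
For readers less familiar with the RPC calls used above, the sketch below shows the same pattern in isolation: build an ordinary local optimizer on the worker that owns the parameters, keep an RRef to it, and drive it from the master with rpc_sync(). This is only a minimal illustration under assumed names — "master"/"worker1", TinyShard, and make_remote_optimizer are not part of the example — and it omits the forward/backward pass, which the real example runs through a dist_autograd context.

# Minimal sketch of the remote-optimizer pattern (names are illustrative).
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim


class TinyShard(nn.Module):
    """Stand-in for one model shard living on a worker."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

    def parameter_rrefs(self):
        # Hand out RRefs so another worker can refer to these parameters.
        return [rpc.RRef(p) for p in self.parameters()]


def make_remote_optimizer(param_rrefs, lr=0.05):
    # Runs on the owning worker: the parameters are local, so to_here() is
    # cheap and the optimizer state never leaves this process.
    params = [p.to_here() for p in param_rrefs]
    opt = optim.SGD(params, lr=lr)
    opt.step = torch.compile(opt.step)  # mirrors the patch above
    return opt


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    name = "master" if rank == 0 else "worker1"
    rpc.init_rpc(name, rank=rank, world_size=world_size)

    if rank == 0:
        # Construct the shard on worker1 and collect RRefs to its parameters.
        shard_rref = rpc.remote("worker1", TinyShard)
        param_rrefs = shard_rref.remote().parameter_rrefs().to_here()

        # Build the optimizer where the parameters live, control it remotely.
        opt_rref = rpc.remote("worker1", make_remote_optimizer, args=(param_rrefs,))
        opt_rref.rpc_sync().zero_grad()
        opt_rref.rpc_sync().step()  # no-op here: no gradients were produced

    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)

The appeal of this layout is that optimizer state and parameter updates stay on the worker that owns each shard; per iteration the master only sends small control messages (step, zero_grad), as in the training loop of the patch above.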
Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-torch==1.9.0
-torchvision==0.7.0
+torch
+torchvision
