Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions distributed/rpc/pipeline/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import threading
import time
import warnings
from functools import wraps

import torch
Expand All @@ -9,11 +10,18 @@
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

from torchvision.models.resnet import Bottleneck

# Suppress warnings that can't be fixed from user code
warnings.filterwarnings("ignore",
message="You are using a Backend .* as a ProcessGroup. This usage is deprecated",
category=UserWarning)
warnings.filterwarnings("ignore",
message="networkx backend defined more than once: nx-loopback",
category=RuntimeWarning)


#########################################################
# Define Model Parallel ResNet50 #
Expand Down Expand Up @@ -185,15 +193,34 @@ def parameter_rrefs(self):
image_h = 128


def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
"""Create torch.compiled optimizers on each worker"""
params = [p.to_here() for p in param_rrefs]
opt = optim.SGD(params, lr=lr)
opt.step = torch.compile(opt.step)
return opt


def run_master(split_size):

# put the two model parts on worker1 and worker2 respectively
model = DistResNet50(split_size, ["worker1", "worker2"])
loss_fn = nn.MSELoss()
opt = DistributedOptimizer(
optim.SGD,
model.parameter_rrefs(),
lr=0.05,

# Get parameter RRefs for each model shard
p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()

# Create optimizers on remote workers
opt1_rref = rpc.remote(
"worker1",
create_optimizer_for_remote_params,
args=("worker1", p1_param_rrefs)
)
opt2_rref = rpc.remote(
"worker2",
create_optimizer_for_remote_params,
args=("worker2", p2_param_rrefs)
)

one_hot_indices = torch.LongTensor(batch_size) \
Expand All @@ -213,7 +240,12 @@ def run_master(split_size):
with dist_autograd.context() as context_id:
outputs = model(inputs)
dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
opt.step(context_id)

opt1_rref.rpc_sync().step()
opt2_rref.rpc_sync().step()

opt1_rref.rpc_sync().zero_grad()
opt2_rref.rpc_sync().zero_grad()


def run_worker(rank, world_size, num_split):
Expand Down Expand Up @@ -245,6 +277,9 @@ def run_worker(rank, world_size, num_split):


if __name__=="__main__":
# Suppress torch compile profiler warnings
os.environ['TORCH_LOGS'] = '-dynamo'

world_size = 3
for num_split in [1, 2, 4, 8]:
tik = time.time()
Expand Down
4 changes: 2 additions & 2 deletions distributed/rpc/pipeline/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch==1.9.0
torchvision==0.7.0
torch
torchvision