From 89e121563601046e8a92a9a715280ee73e77803c Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 25 Aug 2025 12:10:45 -0700
Subject: [PATCH 1/3] Modernize RPC example

---
 distributed/rpc/pipeline/main.py          | 38 +++++++++++++++++++----
 distributed/rpc/pipeline/requirements.txt |  4 +--
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/distributed/rpc/pipeline/main.py b/distributed/rpc/pipeline/main.py
index 5d4e73b83d..4afb4fda2a 100644
--- a/distributed/rpc/pipeline/main.py
+++ b/distributed/rpc/pipeline/main.py
@@ -9,7 +9,6 @@
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
-from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
 
 from torchvision.models.resnet import Bottleneck
@@ -185,15 +184,35 @@ def parameter_rrefs(self):
 image_h = 128
 
 
+def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
+    """Create optimizer on remote worker for given parameters"""
+    params = [p.to_here() for p in param_rrefs]
+    opt = optim.SGD(params, lr=lr)
+    # Use torch.compile to optimize the optimizer step
+    opt.step = torch.compile(opt.step)
+    return opt
+
+
 def run_master(split_size):
 
     # put the two model parts on worker1 and worker2 respectively
     model = DistResNet50(split_size, ["worker1", "worker2"])
     loss_fn = nn.MSELoss()
-    opt = DistributedOptimizer(
-        optim.SGD,
-        model.parameter_rrefs(),
-        lr=0.05,
+
+    # Get parameter RRefs for each model shard
+    p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
+    p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()
+
+    # Create optimizers on remote workers
+    opt1_rref = rpc.remote(
+        "worker1",
+        create_optimizer_for_remote_params,
+        args=("worker1", p1_param_rrefs)
+    )
+    opt2_rref = rpc.remote(
+        "worker2",
+        create_optimizer_for_remote_params,
+        args=("worker2", p2_param_rrefs)
     )
 
     one_hot_indices = torch.LongTensor(batch_size) \
@@ -213,7 +232,14 @@ def run_master(split_size):
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
-            opt.step(context_id)
+
+            # Step both optimizers
+            opt1_rref.rpc_sync().step()
+            opt2_rref.rpc_sync().step()
+
+            # Zero gradients
+            opt1_rref.rpc_sync().zero_grad()
+            opt2_rref.rpc_sync().zero_grad()
 
 
 def run_worker(rank, world_size, num_split):
diff --git a/distributed/rpc/pipeline/requirements.txt b/distributed/rpc/pipeline/requirements.txt
index 22bc42cb99..37f700a78e 100644
--- a/distributed/rpc/pipeline/requirements.txt
+++ b/distributed/rpc/pipeline/requirements.txt
@@ -1,2 +1,2 @@
-torch==1.9.0
-torchvision==0.7.0
\ No newline at end of file
+torch
+torchvision
\ No newline at end of file

From fce10a508ccd0a255088cbbadfc23c39bee6060d Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 25 Aug 2025 12:22:28 -0700
Subject: [PATCH 2/3] update

---
 distributed/rpc/pipeline/main.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/distributed/rpc/pipeline/main.py b/distributed/rpc/pipeline/main.py
index 4afb4fda2a..77007dd40e 100644
--- a/distributed/rpc/pipeline/main.py
+++ b/distributed/rpc/pipeline/main.py
@@ -1,6 +1,7 @@
 import os
 import threading
 import time
+import warnings
 from functools import wraps
 
 import torch
@@ -13,6 +14,14 @@
 
 from torchvision.models.resnet import Bottleneck
 
+# Suppress warnings that can't be fixed from user code
+warnings.filterwarnings("ignore",
+                        message="You are using a Backend .* as a ProcessGroup. This usage is deprecated",
+                        category=UserWarning)
+warnings.filterwarnings("ignore",
+                        message="networkx backend defined more than once: nx-loopback",
+                        category=RuntimeWarning)
+
 
 #########################################################
 #           Define Model Parallel ResNet50              #
@@ -271,6 +280,9 @@ def run_worker(rank, world_size, num_split):
 
 
 if __name__=="__main__":
+    # Suppress torch compile profiler warnings
+    os.environ['TORCH_LOGS'] = '-dynamo'
+
     world_size = 3
     for num_split in [1, 2, 4, 8]:
         tik = time.time()

From f9799096c2db2c3b8e895345f9a6b35a3c81e450 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Mon, 25 Aug 2025 12:24:51 -0700
Subject: [PATCH 3/3] update

---
 distributed/rpc/pipeline/main.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/distributed/rpc/pipeline/main.py b/distributed/rpc/pipeline/main.py
index 77007dd40e..c9226f8763 100644
--- a/distributed/rpc/pipeline/main.py
+++ b/distributed/rpc/pipeline/main.py
@@ -194,10 +194,9 @@ def parameter_rrefs(self):
 
 
 def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
-    """Create optimizer on remote worker for given parameters"""
+    """Create torch.compiled optimizers on each worker"""
     params = [p.to_here() for p in param_rrefs]
     opt = optim.SGD(params, lr=lr)
-    # Use torch.compile to optimize the optimizer step
     opt.step = torch.compile(opt.step)
     return opt
 
@@ -242,11 +241,9 @@ def run_master(split_size):
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
 
-            # Step both optimizers
             opt1_rref.rpc_sync().step()
             opt2_rref.rpc_sync().step()
 
-            # Zero gradients
             opt1_rref.rpc_sync().zero_grad()
             opt2_rref.rpc_sync().zero_grad()
 
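
Note on the pattern introduced above: the series replaces DistributedOptimizer with a plain optim.SGD per shard whose step() is wrapped in torch.compile, created on the owning worker via rpc.remote() and driven through opt_rref.rpc_sync().step(). The snippet below is only a minimal local sketch of that compiled-step idiom, assuming a torch build with torch.compile; nn.Linear stands in for a ResNet shard and no RPC is involved.

    # Standalone sketch of the compiled optimizer step used in the patch
    # (illustrative only; not part of the patch series above).
    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    opt = optim.SGD(model.parameters(), lr=0.05)
    opt.step = torch.compile(opt.step)  # compile only the optimizer step

    for _ in range(3):
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        opt.step()       # first call compiles, later calls reuse the compiled step
        opt.zero_grad()

In the patch itself, the same idiom runs inside create_optimizer_for_remote_params on worker1 and worker2, and the master invokes step()/zero_grad() remotely through the returned optimizer RRefs.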