 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
-from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
 
 from torchvision.models.resnet import Bottleneck
@@ -185,15 +184,35 @@ def parameter_rrefs(self):
 image_h = 128
 
 
+def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
+    """Create optimizer on remote worker for given parameters"""
+    params = [p.to_here() for p in param_rrefs]
+    opt = optim.SGD(params, lr=lr)
+    # Use torch.compile to optimize the optimizer step
+    opt.step = torch.compile(opt.step)
+    return opt
+
+
 def run_master(split_size):
 
     # put the two model parts on worker1 and worker2 respectively
     model = DistResNet50(split_size, ["worker1", "worker2"])
     loss_fn = nn.MSELoss()
-    opt = DistributedOptimizer(
-        optim.SGD,
-        model.parameter_rrefs(),
-        lr=0.05,
+
+    # Get parameter RRefs for each model shard
+    p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
+    p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()
+
+    # Create optimizers on remote workers
+    opt1_rref = rpc.remote(
+        "worker1",
+        create_optimizer_for_remote_params,
+        args=("worker1", p1_param_rrefs)
+    )
+    opt2_rref = rpc.remote(
+        "worker2",
+        create_optimizer_for_remote_params,
+        args=("worker2", p2_param_rrefs)
     )
 
     one_hot_indices = torch.LongTensor(batch_size) \
@@ -213,7 +232,14 @@ def run_master(split_size):
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
-            opt.step(context_id)
+
+            # Step both optimizers
+            opt1_rref.rpc_sync().step()
+            opt2_rref.rpc_sync().step()
+
+            # Zero gradients
+            opt1_rref.rpc_sync().zero_grad()
+            opt2_rref.rpc_sync().zero_grad()
 
 
 def run_worker(rank, world_size, num_split):
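Note: the `run_worker` entry point shown as trailing context above is unchanged by this diff; the sketch below is a minimal, assumed version of how such a launcher is typically wired up with `rpc.init_rpc` and `mp.spawn`, so the modified `run_master` and the remote optimizers have workers to run on. The rendezvous address/port, the world size of 3 (one master plus "worker1" and "worker2"), and the split sizes iterated over are illustrative assumptions, not part of this change.

import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run_worker(rank, world_size, num_split):
    # Assumed single-machine rendezvous settings; adjust for your setup.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

    if rank == 0:
        # Rank 0 acts as the master and drives the training loop in run_master().
        rpc.init_rpc("master", rank=rank, world_size=world_size)
        run_master(num_split)
    else:
        # The other ranks host the model shards and the remote optimizers;
        # they only respond to RPCs issued by the master.
        rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    # Block until all outstanding RPC work is finished, then tear down.
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 3  # one master + two workers (assumed)
    for num_split in [1, 2, 4, 8]:
        mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)

Under this scaffolding, the master issues the `rpc.remote` calls that build the two compiled optimizers, while worker1 and worker2 own the shard parameters and execute the `step()` and `zero_grad()` calls locally through the RRef proxies.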