File tree Expand file tree Collapse file tree 1 file changed +1
-4
lines changed Expand file tree Collapse file tree 1 file changed +1
-4
lines changed Original file line number Diff line number Diff line change @@ -194,10 +194,9 @@ def parameter_rrefs(self):
194
194
195
195
196
196
def create_optimizer_for_remote_params (worker_name , param_rrefs , lr = 0.05 ):
197
- """Create optimizer on remote worker for given parameters """
197
+ """Create a torch.compile-optimized optimizer on the remote worker for the given parameters."""
198
198
params = [p .to_here () for p in param_rrefs ]
199
199
opt = optim .SGD (params , lr = lr )
200
- # Use torch.compile to optimize the optimizer step
201
200
opt .step = torch .compile (opt .step )
202
201
return opt
203
202
@@ -242,11 +241,9 @@ def run_master(split_size):
242
241
outputs = model (inputs )
243
242
dist_autograd .backward (context_id , [loss_fn (outputs , labels )])
244
243
245
- # Step both optimizers
246
244
opt1_rref .rpc_sync ().step ()
247
245
opt2_rref .rpc_sync ().step ()
248
246
249
- # Zero gradients
250
247
opt1_rref .rpc_sync ().zero_grad ()
251
248
opt2_rref .rpc_sync ().zero_grad ()
252
249
You can’t perform that action at this time.
0 commit comments