Skip to content

Commit f979909

Browse files
committed
update
1 parent fce10a5 commit f979909

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

distributed/rpc/pipeline/main.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,9 @@ def parameter_rrefs(self):
194194

195195

196196
def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
    """Build an SGD optimizer over remotely-held parameters.

    Each RRef in *param_rrefs* is materialized locally via ``to_here()``,
    an ``optim.SGD`` instance is created over those parameters, and its
    ``step`` method is wrapped with ``torch.compile`` so the optimizer
    update runs as compiled code.

    Args:
        worker_name: Name of the owning worker. NOTE(review): currently
            unused in the body; kept for interface compatibility with
            callers — confirm whether it is needed.
        param_rrefs: Iterable of RRefs, each resolving to a parameter
            tensor via ``to_here()``.
        lr: Learning rate for SGD (default 0.05).

    Returns:
        The configured ``optim.SGD`` instance with a compiled ``step``.
    """
    # Pull every remote parameter to this worker before handing them to SGD.
    local_params = [rref.to_here() for rref in param_rrefs]
    optimizer = optim.SGD(local_params, lr=lr)
    # Compile the step function so subsequent calls use the optimized graph.
    optimizer.step = torch.compile(optimizer.step)
    return optimizer
203202

@@ -242,11 +241,9 @@ def run_master(split_size):
242241
outputs = model(inputs)
243242
dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
244243

245-
# Step both optimizers
246244
opt1_rref.rpc_sync().step()
247245
opt2_rref.rpc_sync().step()
248246

249-
# Zero gradients
250247
opt1_rref.rpc_sync().zero_grad()
251248
opt2_rref.rpc_sync().zero_grad()
252249

0 commit comments

Comments
 (0)