 import os
 import threading
 import time
+import warnings
 from functools import wraps

 import torch
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.optim as optim
-from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef

 from torchvision.models.resnet import Bottleneck

+# Suppress warnings that can't be fixed from user code
+warnings.filterwarnings("ignore",
+                        message="You are using a Backend .* as a ProcessGroup. This usage is deprecated",
+                        category=UserWarning)
+warnings.filterwarnings("ignore",
+                        message="networkx backend defined more than once: nx-loopback",
+                        category=RuntimeWarning)
+

 #########################################################
 #           Define Model Parallel ResNet50              #
@@ -185,15 +193,34 @@ def parameter_rrefs(self):
 image_h = 128


+def create_optimizer_for_remote_params(worker_name, param_rrefs, lr=0.05):
+    """Create torch.compiled optimizers on each worker"""
+    params = [p.to_here() for p in param_rrefs]
+    opt = optim.SGD(params, lr=lr)
+    opt.step = torch.compile(opt.step)
+    return opt
+
+
 def run_master(split_size):

     # put the two model parts on worker1 and worker2 respectively
     model = DistResNet50(split_size, ["worker1", "worker2"])
     loss_fn = nn.MSELoss()
-    opt = DistributedOptimizer(
-        optim.SGD,
-        model.parameter_rrefs(),
-        lr=0.05,
+
+    # Get parameter RRefs for each model shard
+    p1_param_rrefs = model.p1_rref.remote().parameter_rrefs().to_here()
+    p2_param_rrefs = model.p2_rref.remote().parameter_rrefs().to_here()
+
+    # Create optimizers on remote workers
+    opt1_rref = rpc.remote(
+        "worker1",
+        create_optimizer_for_remote_params,
+        args=("worker1", p1_param_rrefs)
+    )
+    opt2_rref = rpc.remote(
+        "worker2",
+        create_optimizer_for_remote_params,
+        args=("worker2", p2_param_rrefs)
     )

     one_hot_indices = torch.LongTensor(batch_size) \
@@ -213,7 +240,12 @@ def run_master(split_size):
         with dist_autograd.context() as context_id:
             outputs = model(inputs)
             dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
-            opt.step(context_id)
+
+            opt1_rref.rpc_sync().step()
+            opt2_rref.rpc_sync().step()
+
+            opt1_rref.rpc_sync().zero_grad()
+            opt2_rref.rpc_sync().zero_grad()


 def run_worker(rank, world_size, num_split):
@@ -245,6 +277,9 @@ def run_worker(rank, world_size, num_split):


 if __name__ == "__main__":
+    # Suppress torch compile profiler warnings
+    os.environ['TORCH_LOGS'] = '-dynamo'
+
     world_size = 3
     for num_split in [1, 2, 4, 8]:
         tik = time.time()
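
For reference, below is a minimal, self-contained sketch of the pattern the diff switches to: build the optimizer on the worker that owns the parameters, keep only its RRef on the master, and drive step()/zero_grad() through RRef.rpc_sync(), with the same compile-the-step trick applied to opt.step. This is not the commit's code; the two-process setup, the toy nn.Linear "shard", the worker names, and the MASTER_ADDR/MASTER_PORT values are assumptions for illustration, and it deliberately runs forward/backward locally on the worker (plain autograd filling p.grad) instead of the dist_autograd flow used in the example above.

# Hypothetical two-process sketch; names, model, and ports are illustrative only.
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim


def make_optimizer(model_rref, lr=0.05):
    # Runs on the worker that owns the module, so its parameters are local here.
    opt = optim.SGD(model_rref.local_value().parameters(), lr=lr)
    opt.step = torch.compile(opt.step)  # same compile-the-step trick as the diff
    return opt


def local_forward_backward(model_rref, x, y):
    # Runs on the owning worker; plain autograd fills p.grad locally,
    # which is what the worker-local opt.step() consumes.
    model = model_rref.local_value()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss.item()


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc("master" if rank == 0 else "worker1",
                 rank=rank, world_size=world_size)
    if rank == 0:
        # Module and optimizer both live on worker1; master only holds RRefs.
        model_rref = rpc.remote("worker1", nn.Linear, args=(4, 2))
        opt_rref = rpc.remote("worker1", make_optimizer, args=(model_rref,))
        for _ in range(3):
            loss = rpc.rpc_sync("worker1", local_forward_backward,
                                args=(model_rref, torch.randn(8, 4), torch.randn(8, 2)))
            opt_rref.rpc_sync().step()       # executes on worker1
            opt_rref.rpc_sync().zero_grad()  # executes on worker1
            print(f"step done, loss={loss:.4f}")
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)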