Commit 1a24d26

Fix More RPC examples
1 parent f1723eb commit 1a24d26

3 files changed: +27 -6 lines changed

distributed/rpc/batch/parameter_server.py

Lines changed: 15 additions & 1 deletion
@@ -1,15 +1,20 @@
 import os
 import threading
 from datetime import datetime
+import warnings
 
 import torch
+import torch.distributed as dist
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch import optim
 
 import torchvision
 
+# Suppress deprecated ProcessGroup warning
+warnings.filterwarnings("ignore", message="You are using a Backend.*ProcessGroup")
+
 
 batch_size = 20
 image_w = 64
@@ -114,9 +119,17 @@ def run_ps(trainers):
 def run(rank, world_size):
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '29500'
+
+    # Initialize the process group first
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size
+    )
+
     options=rpc.TensorPipeRpcBackendOptions(
         num_worker_threads=16,
-        rpc_timeout=0 # infinite timeout
+        rpc_timeout=60 # 60 second timeout instead of infinite
     )
     if rank != 0:
         rpc.init_rpc(
@@ -137,6 +150,7 @@ def run(rank, world_size):
 
     # block until all rpcs finish
     rpc.shutdown()
+    dist.destroy_process_group()
 
 
 if __name__=="__main__":
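
Taken together, the parameter_server.py changes fix the initialization and teardown order. The sketch below is a minimal illustration rather than the example itself: the worker names and the mp.spawn driver are illustrative assumptions, and only the ordering (process group before RPC, rpc.shutdown() before destroying the group) and the finite rpc_timeout come from the diff.

import os

import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    # 1. Collective process group first, as the patched example does.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # 2. Then the RPC layer, with a finite timeout so hangs surface as errors.
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=60)
    name = "ps" if rank == 0 else f"trainer{rank}"  # illustrative worker names
    rpc.init_rpc(name, rank=rank, world_size=world_size, rpc_backend_options=options)

    # ... issue RPCs here ...

    # 3. Tear down in reverse order.
    rpc.shutdown()                # blocks until all outstanding RPCs finish
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)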

distributed/rpc/batch/reinforce.py

Lines changed: 10 additions & 3 deletions
@@ -3,6 +3,10 @@
 import os
 import threading
 import time
+import warnings
+
+# Suppress deprecated ProcessGroup warning
+warnings.filterwarnings("ignore", message="You are using a Backend.*ProcessGroup")
 
 import torch
 import torch.distributed.rpc as rpc
@@ -26,6 +30,8 @@
                     help='random seed (default: 543)')
 parser.add_argument('--num-episode', type=int, default=10, metavar='E',
                     help='number of episodes (default: 10)')
+parser.add_argument('--max-world-size', type=int, default=3, metavar='W',
+                    help='maximum world size to test (default: 3)')
 args = parser.parse_args()
 
 torch.manual_seed(args.seed)
@@ -79,7 +85,8 @@ def run_episode(self, agent_rref, n_steps):
             agent_rref (RRef): an RRef referencing the agent object.
             n_steps (int): number of steps in this episode
         """
-        state, ep_reward = self.env.reset(), NUM_STEPS
+        state, _ = self.env.reset()
+        ep_reward = NUM_STEPS
         rewards = torch.zeros(n_steps)
         start_step = 0
         for step in range(n_steps):
@@ -101,7 +108,7 @@ def run_episode(self, agent_rref, n_steps):
                 for i in range(curr_rewards.numel() -1, -1, -1):
                     R = curr_rewards[i] + args.gamma * R
                     curr_rewards[i] = R
-                state = self.env.reset()
+                state, _ = self.env.reset()
                 if start_step == 0:
                     ep_reward = min(ep_reward, step - start_step + 1)
                 start_step = step + 1
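
Both reset() hunks adapt to the Gymnasium API, where env.reset() returns an (observation, info) tuple instead of a bare observation, so the example unpacks and discards the info dict. A standalone sketch of the difference; CartPole-v1 is used here only to make the snippet runnable and is not taken from this diff:

import gymnasium as gym

env = gym.make("CartPole-v1")

# Old Gym API (pre-0.26):  state = env.reset()
# Gymnasium API:           state, info = env.reset()
state, _ = env.reset()           # discard the info dict, as the example now does
print(type(state), state.shape)  # numpy observation, e.g. shape (4,) for CartPole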
@@ -235,7 +242,7 @@ def run_worker(rank, world_size, n_episode, batch, print_log=True):
 
 
 def main():
-    for world_size in range(2, 12):
+    for world_size in range(2, args.max_world_size):
         delays = []
         for batch in [True, False]:
             tik = time.time()
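
The new --max-world-size flag replaces the hard-coded upper bound of 12. Note that range() excludes its end, so the default of 3 benchmarks world_size=2 only; pass a larger value to sweep further. A minimal sketch of just the flag and the loop it bounds, with a print standing in for the benchmark body:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max-world-size', type=int, default=3, metavar='W',
                    help='maximum world size to test (default: 3)')
args = parser.parse_args()

for world_size in range(2, args.max_world_size):
    for batch in [True, False]:
        print(f"would benchmark world_size={world_size}, batch={batch}")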
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-torch==2.2.0
-torchvision==0.7.0
+torch
+torchvision
 numpy
 gymnasium
