3
3
import os
4
4
import threading
5
5
import time
6
+ import warnings
7
+
8
+ # Suppress deprecated ProcessGroup warning
9
+ warnings .filterwarnings ("ignore" , message = "You are using a Backend.*ProcessGroup" )
6
10
7
11
import torch
8
12
import torch .distributed .rpc as rpc
26
30
help = 'random seed (default: 543)' )
27
31
parser .add_argument ('--num-episode' , type = int , default = 10 , metavar = 'E' ,
28
32
help = 'number of episodes (default: 10)' )
33
+ parser .add_argument ('--max-world-size' , type = int , default = 3 , metavar = 'W' ,
34
+ help = 'maximum world size to test (default: 3)' )
29
35
args = parser .parse_args ()
30
36
31
37
torch .manual_seed (args .seed )
@@ -79,7 +85,8 @@ def run_episode(self, agent_rref, n_steps):
79
85
agent_rref (RRef): an RRef referencing the agent object.
80
86
n_steps (int): number of steps in this episode
81
87
"""
82
- state , ep_reward = self .env .reset (), NUM_STEPS
88
+ state , _ = self .env .reset ()
89
+ ep_reward = NUM_STEPS
83
90
rewards = torch .zeros (n_steps )
84
91
start_step = 0
85
92
for step in range (n_steps ):
@@ -101,7 +108,7 @@ def run_episode(self, agent_rref, n_steps):
101
108
for i in range (curr_rewards .numel () - 1 , - 1 , - 1 ):
102
109
R = curr_rewards [i ] + args .gamma * R
103
110
curr_rewards [i ] = R
104
- state = self .env .reset ()
111
+ state , _ = self .env .reset ()
105
112
if start_step == 0 :
106
113
ep_reward = min (ep_reward , step - start_step + 1 )
107
114
start_step = step + 1
@@ -235,7 +242,7 @@ def run_worker(rank, world_size, n_episode, batch, print_log=True):
235
242
236
243
237
244
def main ():
238
- for world_size in range (2 , 12 ):
245
+ for world_size in range (2 , args . max_world_size ):
239
246
delays = []
240
247
for batch in [True , False ]:
241
248
tik = time .time ()
0 commit comments