66import optax
77from chex import Array
88from flax import linen as nn
9- from tqdm import tqdm
109
1110from reinforced_lib import RLib
1211from reinforced_lib .agents .deep import PPODiscrete
@@ -48,14 +47,14 @@ def __call__(self, x: Array) -> tuple[Array, Array]:
4847 return logits , values
4948
5049
51- def run (num_steps : int , num_envs : int , seed : int ) -> None :
50+ def run (time_limit : float , num_envs : int , seed : int ) -> None :
5251 """
5352 Run cart-pole Gymnasium steps until ``time_limit`` seconds elapse.
5453
5554 Parameters
5655 ----------
57- num_steps : int
58- Number of simulation steps to perform .
56+ time_limit : float
57+ Maximum time (in seconds) to run the experiment .
5958 num_envs : int
6059 Number of parallel environments to use.
6160 seed : int
@@ -96,15 +95,12 @@ def make_env():
9695 return_0 , step = 0 , 0
9796 start_time = time .perf_counter ()
9897
99- pbar = tqdm (total = num_steps )
100-
101- while step < num_steps :
98+ while time .perf_counter () - start_time < time_limit :
10299 env_states = env .step (np .asarray (actions ))
103100 actions = rl .sample (* env_states )
104101
105102 return_0 += env_states [1 ][0 ]
106103 step += num_envs
107- pbar .update (num_envs )
108104
109105 if env_states [2 ][0 ] or env_states [3 ][0 ]:
110106 rl .log ('return' , return_0 )
@@ -116,7 +112,7 @@ def make_env():
116112if __name__ == '__main__' :
117113 args = ArgumentParser ()
118114
119- args .add_argument ('--num_steps ' , default = int ( 1e7 ) , type = int )
115+ args .add_argument ('--time_limit ' , default = 120 , type = float )
120116 args .add_argument ('--num_envs' , default = 64 , type = int )
121117 args .add_argument ('--seed' , default = 42 , type = int )
122118
0 commit comments