Skip to content

Commit cb98907

Browse files
committed
Change example to time limit
1 parent afe9eb1 commit cb98907

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

examples/cart-pole-vectorized/main.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import optax
77
from chex import Array
88
from flax import linen as nn
9-
from tqdm import tqdm
109

1110
from reinforced_lib import RLib
1211
from reinforced_lib.agents.deep import PPODiscrete
@@ -48,14 +47,14 @@ def __call__(self, x: Array) -> tuple[Array, Array]:
4847
return logits, values
4948

5049

51-
def run(num_steps: int, num_envs: int, seed: int) -> None:
50+
def run(time_limit: float, num_envs: int, seed: int) -> None:
5251
"""
5352
Run cart-pole Gymnasium steps until ``time_limit`` seconds have elapsed.
5453
5554
Parameters
5655
----------
57-
num_steps : int
58-
Number of simulation steps to perform.
56+
time_limit : float
57+
Maximum time (in seconds) to run the experiment.
5958
num_envs : int
6059
Number of parallel environments to use.
6160
seed : int
@@ -96,15 +95,12 @@ def make_env():
9695
return_0, step = 0, 0
9796
start_time = time.perf_counter()
9897

99-
pbar = tqdm(total=num_steps)
100-
101-
while step < num_steps:
98+
while time.perf_counter() - start_time < time_limit:
10299
env_states = env.step(np.asarray(actions))
103100
actions = rl.sample(*env_states)
104101

105102
return_0 += env_states[1][0]
106103
step += num_envs
107-
pbar.update(num_envs)
108104

109105
if env_states[2][0] or env_states[3][0]:
110106
rl.log('return', return_0)
@@ -116,7 +112,7 @@ def make_env():
116112
if __name__ == '__main__':
117113
args = ArgumentParser()
118114

119-
args.add_argument('--num_steps', default=int(1e7), type=int)
115+
args.add_argument('--time_limit', default=120, type=float)
120116
args.add_argument('--num_envs', default=64, type=int)
121117
args.add_argument('--seed', default=42, type=int)
122118

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,5 @@ full = [
5050
# "ns3-ai==1.0.2",
5151
"pygame~=2.6.1",
5252
"seaborn~=0.13.2",
53-
"tensorflow~=2.19.1",
54-
"tqdm~=4.67.1"
53+
"tensorflow~=2.19.1"
5554
]

0 commit comments

Comments
 (0)