Skip to content

Commit 0f1af58

Browse files
committed
Update vectorized cart pole example
1 parent c2d82a8 commit 0f1af58

File tree

1 file changed

+21
-24
lines changed
  • examples/cart-pole-vectorized

1 file changed

+21
-24
lines changed

examples/cart-pole-vectorized/main.py

Lines changed: 21 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import time
12
from argparse import ArgumentParser
23

34
import gymnasium as gym
@@ -9,7 +10,7 @@
910
from reinforced_lib import RLib
1011
from reinforced_lib.agents.deep import PPODiscrete
1112
from reinforced_lib.exts import GymnasiumVectorized
12-
from reinforced_lib.logs import StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger
13+
from reinforced_lib.logs import CsvLogger, StdoutLogger
1314

1415

1516
class ActionNetwork(nn.Module):
@@ -75,47 +76,43 @@ def run(num_steps: int, num_envs: int, seed: int) -> None:
7576
'value_coef': 0.5,
7677
'rollout_length': 32,
7778
'num_envs': num_envs,
78-
'batch_size': 512,
79+
'batch_size': (num_envs * 32) // 4,
7980
'num_epochs': 4
8081
},
8182
ext_type=GymnasiumVectorized,
8283
ext_params={'env_id': 'CartPole-v1', 'num_envs': num_envs},
83-
logger_types=[StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger]
84+
logger_types=[StdoutLogger, CsvLogger],
85+
logger_params={'csv_path': f'cartpole-ppo-{num_envs}-envs-{seed}.csv'}
8486
)
8587

8688
def make_env():
8789
return gym.make('CartPole-v1', render_mode='no')
8890

89-
step = 0
91+
env = gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)])
92+
_, _ = env.reset(seed=seed)
9093

91-
while step < num_steps:
92-
env = gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)])
93-
94-
_, _ = env.reset(seed=seed + step)
95-
actions = env.action_space.sample()
96-
97-
terminal = np.array([False] * num_envs)
98-
max_epoch_len, min_epoch_len = 0, 0
94+
actions = env.action_space.sample()
95+
return_0, step = 0, 0
96+
start_time = time.perf_counter()
9997

100-
while not np.all(terminal):
101-
env_states = env.step(np.asarray(actions))
102-
actions = rl.sample(*env_states)
103-
104-
terminal = terminal | env_states[2] | env_states[3]
105-
max_epoch_len += 1
98+
while step < num_steps:
99+
env_states = env.step(np.asarray(actions))
100+
actions = rl.sample(*env_states)
106101

107-
if not np.any(terminal):
108-
min_epoch_len += 1
102+
return_0 += env_states[1][0]
103+
step += num_envs
109104

110-
rl.log('max_epoch_len', max_epoch_len)
111-
rl.log('min_epoch_len', min_epoch_len)
112-
step += max_epoch_len * num_envs
105+
if env_states[2][0] or env_states[3][0]:
106+
rl.log('return', return_0)
107+
rl.log('steps', step)
108+
rl.log('time', time.perf_counter() - start_time)
109+
return_0 = 0
113110

114111

115112
if __name__ == '__main__':
116113
args = ArgumentParser()
117114

118-
args.add_argument('--num_steps', default=int(1e6), type=int)
115+
args.add_argument('--num_steps', default=int(1e7), type=int)
119116
args.add_argument('--num_envs', default=64, type=int)
120117
args.add_argument('--seed', default=42, type=int)
121118

0 commit comments

Comments (0)