|
| 1 | +import time |
1 | 2 | from argparse import ArgumentParser |
2 | 3 |
|
3 | 4 | import gymnasium as gym |
|
9 | 10 | from reinforced_lib import RLib |
10 | 11 | from reinforced_lib.agents.deep import PPODiscrete |
11 | 12 | from reinforced_lib.exts import GymnasiumVectorized |
12 | | -from reinforced_lib.logs import StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger |
| 13 | +from reinforced_lib.logs import CsvLogger, StdoutLogger |
13 | 14 |
|
14 | 15 |
|
15 | 16 | class ActionNetwork(nn.Module): |
@@ -75,47 +76,43 @@ def run(num_steps: int, num_envs: int, seed: int) -> None: |
75 | 76 | 'value_coef': 0.5, |
76 | 77 | 'rollout_length': 32, |
77 | 78 | 'num_envs': num_envs, |
78 | | - 'batch_size': 512, |
| 79 | + 'batch_size': (num_envs * 32) // 4, |
79 | 80 | 'num_epochs': 4 |
80 | 81 | }, |
81 | 82 | ext_type=GymnasiumVectorized, |
82 | 83 | ext_params={'env_id': 'CartPole-v1', 'num_envs': num_envs}, |
83 | | - logger_types=[StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger] |
| 84 | + logger_types=[StdoutLogger, CsvLogger], |
| 85 | + logger_params={'csv_path': f'cartpole-ppo-{num_envs}-envs-{seed}.csv'} |
84 | 86 | ) |
85 | 87 |
|
86 | 88 | def make_env(): |
87 | 89 | return gym.make('CartPole-v1', render_mode=None)
88 | 90 |
|
89 | | - step = 0 |
| 91 | + env = gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)]) |
| 92 | + _, _ = env.reset(seed=seed) |
90 | 93 |
|
91 | | - while step < num_steps: |
92 | | - env = gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)]) |
93 | | - |
94 | | - _, _ = env.reset(seed=seed + step) |
95 | | - actions = env.action_space.sample() |
96 | | - |
97 | | - terminal = np.array([False] * num_envs) |
98 | | - max_epoch_len, min_epoch_len = 0, 0 |
| 94 | + actions = env.action_space.sample() |
| 95 | + return_0, step = 0, 0 |
| 96 | + start_time = time.perf_counter() |
99 | 97 |
|
100 | | - while not np.all(terminal): |
101 | | - env_states = env.step(np.asarray(actions)) |
102 | | - actions = rl.sample(*env_states) |
103 | | - |
104 | | - terminal = terminal | env_states[2] | env_states[3] |
105 | | - max_epoch_len += 1 |
| 98 | + while step < num_steps: |
| 99 | + env_states = env.step(np.asarray(actions)) |
| 100 | + actions = rl.sample(*env_states) |
106 | 101 |
|
107 | | - if not np.any(terminal): |
108 | | - min_epoch_len += 1 |
| 102 | + return_0 += env_states[1][0] |
| 103 | + step += num_envs |
109 | 104 |
|
110 | | - rl.log('max_epoch_len', max_epoch_len) |
111 | | - rl.log('min_epoch_len', min_epoch_len) |
112 | | - step += max_epoch_len * num_envs |
| 105 | + if env_states[2][0] or env_states[3][0]: |
| 106 | + rl.log('return', return_0) |
| 107 | + rl.log('steps', step) |
| 108 | + rl.log('time', time.perf_counter() - start_time) |
| 109 | + return_0 = 0 |
113 | 110 |
|
114 | 111 |
|
115 | 112 | if __name__ == '__main__': |
116 | 113 | args = ArgumentParser() |
117 | 114 |
|
118 | | - args.add_argument('--num_steps', default=int(1e6), type=int) |
| 115 | + args.add_argument('--num_steps', default=int(1e7), type=int) |
119 | 116 | args.add_argument('--num_envs', default=64, type=int) |
120 | 117 | args.add_argument('--seed', default=42, type=int) |
121 | 118 |
|
|
0 commit comments