
Commit b73695e

Update evolutionary algorithms example
1 parent cb98907 commit b73695e

4 files changed: 42 additions, 27 deletions

examples/cart-pole-evo/main.py

Lines changed: 26 additions & 23 deletions
@@ -1,16 +1,15 @@
 from argparse import ArgumentParser
 
+import evosax.algorithms
 import gymnasium as gym
-import jax
 import numpy as np
 from chex import Array
-from evosax.algorithms import PGPE
 from flax import linen as nn
 
 from reinforced_lib import RLib
 from reinforced_lib.agents.neuro import Evosax
 from reinforced_lib.exts import GymnasiumVectorized
-from reinforced_lib.logs import StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger
+from reinforced_lib.logs import CsvLogger, StdoutLogger
 
 
 class Network(nn.Module):
@@ -20,17 +19,19 @@ def __call__(self, x: Array) -> Array:
         x = nn.tanh(x)
         x = nn.Dense(64)(x)
         x = nn.tanh(x)
-        logits = nn.Dense(2)(x)
-        action = jax.random.categorical(self.make_rng('rlib'), logits)
+        x = nn.Dense(1)(x)
+        action = 2 * nn.tanh(x)
         return action
 
 
-def run(num_epochs: int, population_size: int, seed: int) -> None:
+def run(evo_alg: type, num_epochs: int, population_size: int, seed: int) -> None:
     """
-    Run ``num_epochs`` cart-pole Gymnasium environments in parallel using an evolutionary strategy to optimize the policy.
+    Run ``population_size`` Pendulum Gymnasium environments in parallel using an evolutionary strategy to optimize the policy.
 
     Parameters
     ----------
+    evo_alg : type
+        Evolutionary strategy to use (from evosax).
     num_epochs : int
         Number of simulation steps to perform.
     population_size : int
@@ -43,44 +44,46 @@ def run(num_epochs: int, population_size: int, seed: int) -> None:
         agent_type=Evosax,
         agent_params={
             'network': Network(),
-            'evo_strategy': PGPE,
+            'evo_strategy': evo_alg,
             'evo_strategy_default_params': {'std_init': 0.1},
             'population_size': population_size
         },
         ext_type=GymnasiumVectorized,
-        ext_params={'env_id': 'CartPole-v1', 'num_envs': population_size},
-        logger_types=[StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger]
+        ext_params={'env_id': 'Pendulum-v1', 'num_envs': population_size},
+        logger_types=[CsvLogger, StdoutLogger],
+        logger_params={'csv_path': f'pendulum-{evo_alg.__name__}-evo-{seed}.csv'}
     )
 
     def make_env():
-        return gym.make('CartPole-v1', render_mode='no')
+        return gym.make('Pendulum-v1')
 
-    for step in range(num_epochs):
-        env = gym.vector.SyncVectorEnv([make_env for _ in range(population_size)])
+    env = gym.vector.SyncVectorEnv([make_env for _ in range(population_size)])
 
-        _, _ = env.reset(seed=seed + step)
+    for epoch in range(num_epochs):
+        _, _ = env.reset(seed=seed + epoch)
         actions = env.action_space.sample()
+        return_pop = np.zeros(population_size, dtype=float)
 
-        terminal = np.array([False] * population_size)
-        max_epoch_len = 0
-
-        while not np.all(terminal):
+        for _ in range(env.envs[0].spec.max_episode_steps):
             env_states = env.step(np.asarray(actions))
             actions = rl.sample(*env_states)
+            return_pop += env_states[1]
 
-            terminal = terminal | env_states[2] | env_states[3]
-            max_epoch_len += 1
-
-        rl.log('max_epoch_len', max_epoch_len)
+        rl.log('mean_return', return_pop.mean())
+        rl.log('max_return', return_pop.max())
+        rl.log('epoch', epoch + 1)
 
 
 if __name__ == '__main__':
     args = ArgumentParser()
 
+    args.add_argument('--evo_alg', type=str, required=True)
     args.add_argument('--num_epochs', default=300, type=int)
     args.add_argument('--population_size', default=64, type=int)
     args.add_argument('--seed', default=42, type=int)
 
     args = args.parse_args()
 
-    run(**(vars(args)))
+    args = vars(args)
+    args['evo_alg'] = getattr(evosax.algorithms, args['evo_alg'])
+    run(**args)
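
The key change above is the policy head: CartPole's two-way categorical output is replaced by a single continuous output scaled to Pendulum-v1's torque range of [-2, 2]. Below is a minimal standalone sketch of that head, run outside Reinforced-lib purely to inspect shapes and value ranges; the class name PendulumPolicy is illustrative, and the hidden layers before the shown hunk are assumed to mirror the visible Dense(64)/tanh pair.

import jax
import jax.numpy as jnp
from flax import linen as nn

class PendulumPolicy(nn.Module):
    # Illustrative mirror of the Network in the diff above.
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(64)(x)
        x = nn.tanh(x)
        x = nn.Dense(1)(x)
        return 2 * nn.tanh(x)      # Pendulum-v1 actions (torque) lie in [-2, 2]

obs = jnp.zeros((8, 3))            # Pendulum-v1 observations are 3-dimensional
params = PendulumPolicy().init(jax.random.PRNGKey(0), obs)
actions = PendulumPolicy().apply(params, obs)
print(actions.shape)               # (8, 1), every value inside (-2, 2)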

examples/cart-pole-evo/run_all.sh

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+evo_algs=("CMA_ES" "PGPE" "SimpleGA")
+seeds=(1 2 3 4 5 6 7 8 9 10)
+
+for alg in "${evo_algs[@]}"; do
+    for s in "${seeds[@]}"; do
+        echo "Running with $alg and seed $s"
+        python main.py --evo_alg $alg --seed $s
+    done
+done
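
The sweep above relies on main.py resolving each entry of evo_algs via getattr(evosax.algorithms, name). A quick sanity check before launching the full sweep might look like the sketch below; it assumes evosax is installed and reuses only the names already listed in run_all.sh.

import evosax.algorithms

for name in ('CMA_ES', 'PGPE', 'SimpleGA'):
    # Fail loudly if a name from run_all.sh does not resolve to an evosax class.
    assert hasattr(evosax.algorithms, name), f'{name} is not available in evosax.algorithms'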

examples/cart-pole-vectorized/main.py

Lines changed: 4 additions & 3 deletions
@@ -49,7 +49,8 @@ def __call__(self, x: Array) -> tuple[Array, Array]:
 
 def run(time_limit: float, num_envs: int, seed: int) -> None:
     """
-    Run ``num_steps`` cart-pole Gymnasium steps.
+    Run ``num_envs`` CartPole Gymnasium environments in parallel using PPO to optimize the policy.
+    The experiment runs for a maximum of ``time_limit`` seconds.
 
     Parameters
     ----------
@@ -86,7 +87,7 @@ def run(time_limit: float, num_envs: int, seed: int) -> None:
     )
 
     def make_env():
-        return gym.make('CartPole-v1', render_mode='no')
+        return gym.make('CartPole-v1')
 
     env = gym.vector.SyncVectorEnv([make_env for _ in range(num_envs)])
     _, _ = env.reset(seed=seed)
@@ -112,7 +113,7 @@ def make_env():
 if __name__ == '__main__':
     args = ArgumentParser()
 
-    args.add_argument('--time_limit', default=120, type=float)
+    args.add_argument('--time_limit', default=85, type=float)
     args.add_argument('--num_envs', default=64, type=int)
     args.add_argument('--seed', default=42, type=int)
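
The reworded docstring describes a wall-clock budget rather than a fixed step count. A minimal sketch of that control flow is shown below; run_for and do_training_step are hypothetical names used only to illustrate the pattern, not functions from the example.

from time import time

def do_training_step() -> None:
    # Placeholder for one vectorized environment step plus a PPO update.
    pass

def run_for(time_limit: float) -> int:
    # Keep stepping until the wall-clock budget is exhausted.
    steps, start = 0, time()
    while time() - start < time_limit:
        do_training_step()
        steps += 1
    return steps

print(run_for(0.1))    # e.g. a 0.1-second budget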

examples/cart-pole-vectorized/run_all.sh

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-num_envs=(2 4 8 16 32 64 128)
+num_envs=(2 4 8 16 32 64)
 seeds=(1 2 3 4 5 6 7 8 9 10)
 
 for n in "${num_envs[@]}"; do
