Skip to content

Commit 5ca3880

Browse files
committed
Update evolutionary algorithms example
1 parent 4cd5400 commit 5ca3880

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

examples/cart-pole-vectorized/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def run(time_limit: float, num_envs: int, seed: int) -> None:
8282
},
8383
ext_type=GymnasiumVectorized,
8484
ext_params={'env_id': 'CartPole-v1', 'num_envs': num_envs},
85-
logger_types=[CsvLogger],
85+
logger_types=CsvLogger,
8686
logger_params={'csv_path': f'cartpole-ppo-{num_envs}-envs-{seed}.csv'}
8787
)
8888

examples/pendulum-evo/main.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import evosax.algorithms
44
import gymnasium as gym
55
import numpy as np
6+
import optax
67
from chex import Array
78
from flax import linen as nn
89

910
from reinforced_lib import RLib
1011
from reinforced_lib.agents.neuro import Evosax
1112
from reinforced_lib.exts import GymnasiumVectorized
12-
from reinforced_lib.logs import CsvLogger, StdoutLogger
13+
from reinforced_lib.logs import CsvLogger
1314

1415

1516
class Network(nn.Module):
@@ -40,17 +41,23 @@ def run(evo_alg: type, num_epochs: int, population_size: int, seed: int) -> None
4041
Integer used as the random key.
4142
"""
4243

44+
if issubclass(evo_alg, evosax.algorithms.SimpleES):
45+
evo_kwargs = {'optimizer': optax.adam(0.03)}
46+
else:
47+
evo_kwargs = {}
48+
4349
rl = RLib(
4450
agent_type=Evosax,
4551
agent_params={
4652
'network': Network(),
4753
'evo_strategy': evo_alg,
48-
'evo_strategy_default_params': {'std_init': 0.1},
54+
'evo_strategy_kwargs': evo_kwargs,
55+
'evo_strategy_default_params': {'std_init': 0.05},
4956
'population_size': population_size
5057
},
5158
ext_type=GymnasiumVectorized,
5259
ext_params={'env_id': 'Pendulum-v1', 'num_envs': population_size},
53-
logger_types=[CsvLogger, StdoutLogger],
60+
logger_types=CsvLogger,
5461
logger_params={'csv_path': f'pendulum-{evo_alg.__name__}-evo-{seed}.csv'}
5562
)
5663

@@ -78,7 +85,7 @@ def make_env():
7885
args = ArgumentParser()
7986

8087
args.add_argument('--evo_alg', type=str, required=True)
81-
args.add_argument('--num_epochs', default=300, type=int)
88+
args.add_argument('--num_epochs', default=500, type=int)
8289
args.add_argument('--population_size', default=64, type=int)
8390
args.add_argument('--seed', default=42, type=int)
8491

examples/pendulum-evo/run_all.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#!/bin/bash
22

3-
evo_algs=("CMA_ES" "PGPE" "SimpleGA")
3+
evo_algs=("CMA_ES" "PGPE" "SimpleES")
44
seeds=(1 2 3 4 5 6 7 8 9 10)
55

66
for alg in "${evo_algs[@]}"; do
77
for s in "${seeds[@]}"; do
8-
echo "Running with $n environments and seed $s"
8+
echo "Running with algorithm $alg and seed $s"
99
python main.py --evo_alg $alg --seed $s
1010
done
1111
done

reinforced_lib/agents/neuro/evosax.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,22 @@ class EvosaxState(AgentState):
4444

4545
class Evosax(BaseAgent):
4646
r"""
47-
Evolution strategies (ES)-based agent using the evosax library [12]_. This implementation maintains a population
47+
Evolution strategies (ES)-based agent using the ``evosax`` library [12]_. This implementation maintains a population
4848
of candidate solutions (parameter vectors), evaluates them in parallel across environments, and updates the
4949
population by applying an evolutionary algorithm. Unlike gradient-based RL methods, this agent does not rely
5050
on backpropagation through the value or policy network. Instead, the network parameters are evolved using
5151
black-box optimization. This agent is suitable for environments with both discrete and continuous action spaces.
52-
The user is responsible for providing appropriate network output in the correct format (e.g., discrete actions
53-
should be sampled from logits with ``jax.random.categorical`` inside the network definition). Note that
54-
this agent does not discount future rewards, therefore, the fitness is computed as a simple sum of rewards
55-
obtained during the evaluation phase.
52+
53+
**Note!** The user is responsible for providing appropriate network output in the correct format (e.g., discrete
54+
actions should be sampled from logits with ``jax.random.categorical`` inside the network definition).
55+
56+
**Note!** This agent does not discount future rewards, therefore, the fitness is computed as a simple sum of
57+
rewards obtained during the evaluation phase.
58+
59+
**Note!** This agent is compatible only with distribution-based evolution strategies from the ``evosax`` library
60+
(see `this list <https://github.com/RobertTLange/evosax/tree/main/evosax/algorithms/distribution_based>`_ for
61+
available algorithms). Population-based methods (`listed here <https://github.com/RobertTLange/evosax/tree/main/evosax/algorithms/population_based>`_)
62+
will be supported in future releases.
5663
5764
Parameters
5865
----------

0 commit comments

Comments
 (0)