@@ -1,16 +1,15 @@
 from argparse import ArgumentParser
 
+import evosax.algorithms
 import gymnasium as gym
-import jax
 import numpy as np
 from chex import Array
-from evosax.algorithms import PGPE
 from flax import linen as nn
 
 from reinforced_lib import RLib
 from reinforced_lib.agents.neuro import Evosax
 from reinforced_lib.exts import GymnasiumVectorized
-from reinforced_lib.logs import StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger
+from reinforced_lib.logs import CsvLogger, StdoutLogger
 
 
 class Network(nn.Module):
@@ -20,17 +19,19 @@ def __call__(self, x: Array) -> Array:
         x = nn.tanh(x)
         x = nn.Dense(64)(x)
         x = nn.tanh(x)
-        logits = nn.Dense(2)(x)
-        action = jax.random.categorical(self.make_rng('rlib'), logits)
+        x = nn.Dense(1)(x)
+        action = 2 * nn.tanh(x)  # scale to Pendulum's continuous action range [-2, 2]
         return action
 
 
-def run(num_epochs: int, population_size: int, seed: int) -> None:
+def run(evo_alg: type, num_epochs: int, population_size: int, seed: int) -> None:
     """
-    Run ``num_epochs`` cart-pole Gymnasium environments in parallel using an evolutionary strategy to optimize the policy.
+    Run ``population_size`` Pendulum Gymnasium environments in parallel, using an evolutionary strategy to optimize the policy.
 
     Parameters
     ----------
+    evo_alg : type
+        Evolutionary strategy to use (from evosax).
     num_epochs : int
-        Number of simulation steps to perform.
+        Number of training epochs to perform.
     population_size : int
@@ -43,44 +44,46 @@ def run(num_epochs: int, population_size: int, seed: int) -> None:
         agent_type=Evosax,
         agent_params={
             'network': Network(),
-            'evo_strategy': PGPE,
+            'evo_strategy': evo_alg,
             'evo_strategy_default_params': {'std_init': 0.1},
             'population_size': population_size
         },
         ext_type=GymnasiumVectorized,
-        ext_params={'env_id': 'CartPole-v1', 'num_envs': population_size},
-        logger_types=[StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger]
+        ext_params={'env_id': 'Pendulum-v1', 'num_envs': population_size},
+        logger_types=[CsvLogger, StdoutLogger],
+        logger_params={'csv_path': f'pendulum-{evo_alg.__name__}-evo-{seed}.csv'}
     )
 
     def make_env():
-        return gym.make('CartPole-v1', render_mode='no')
+        return gym.make('Pendulum-v1')
 
-    for step in range(num_epochs):
-        env = gym.vector.SyncVectorEnv([make_env for _ in range(population_size)])
+    env = gym.vector.SyncVectorEnv([make_env for _ in range(population_size)])
 
-        _, _ = env.reset(seed=seed + step)
+    for epoch in range(num_epochs):
+        _, _ = env.reset(seed=seed + epoch)
         actions = env.action_space.sample()
+        return_pop = np.zeros(population_size, dtype=float)  # cumulative return of each population member
 
-        terminal = np.array([False] * population_size)
-        max_epoch_len = 0
-
-        while not np.all(terminal):
+        for _ in range(env.envs[0].spec.max_episode_steps):
             env_states = env.step(np.asarray(actions))
             actions = rl.sample(*env_states)
+            return_pop += env_states[1]  # accumulate the rewards returned by env.step
 
-            terminal = terminal | env_states[2] | env_states[3]
-            max_epoch_len += 1
-
-        rl.log('max_epoch_len', max_epoch_len)
+        rl.log('mean_return', return_pop.mean())
+        rl.log('max_return', return_pop.max())
+        rl.log('epoch', epoch + 1)
 
 
 if __name__ == '__main__':
     args = ArgumentParser()
 
+    args.add_argument('--evo_alg', type=str, required=True)
     args.add_argument('--num_epochs', default=300, type=int)
     args.add_argument('--population_size', default=64, type=int)
     args.add_argument('--seed', default=42, type=int)
 
     args = args.parse_args()
 
-    run(**(vars(args)))
+    args = vars(args)
+    args['evo_alg'] = getattr(evosax.algorithms, args['evo_alg'])
+    run(**args)
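For reference, the refactored script resolves the strategy class by name (e.g. --evo_alg PGPE). Below is a minimal sketch of the equivalent programmatic call, assuming the script above is saved as pendulum_evosax.py (a hypothetical filename) and using PGPE, the strategy the previous version of this example hard-coded; any other class exported by evosax.algorithms should resolve the same way:

    import evosax.algorithms

    from pendulum_evosax import run  # hypothetical module name for the script above

    # Mirror the getattr lookup in the __main__ block; equivalent to
    #   python pendulum_evosax.py --evo_alg PGPE
    evo_alg = getattr(evosax.algorithms, 'PGPE')
    run(evo_alg=evo_alg, num_epochs=300, population_size=64, seed=42)

With these arguments, the CsvLogger should write its output to pendulum-PGPE-evo-42.csv, following the csv_path template configured above.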