Skip to content

Commit 093a141

Browse files
committed
Configurable default params in Evosax agent
1 parent 5769e4e commit 093a141

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

examples/cart-pole-evo/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def run(num_epochs: int, population_size: int, seed: int) -> None:
4444
agent_params={
4545
'network': Network(),
4646
'evo_strategy': PGPE,
47+
'evo_strategy_default_params': {'std_init': 0.1},
4748
'population_size': population_size
4849
},
4950
ext_type=GymnasiumVectorized,

reinforced_lib/agents/neuro/evosax.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ class Evosax(BaseAgent):
6464
Shape of the observation space.
6565
act_space_shape : Shape, default=(1,)
6666
Shape of the action space. For discrete action spaces, use (1,).
67-
evo_strategy_params : dict, default=None
68-
Parameters for the evolution strategy.
67+
evo_strategy_kwargs : dict, default=None
68+
Parameters for the evolution strategy initialization. The population size and initial solution are set automatically.
69+
evo_strategy_default_params : dict, default=None
70+
Values that override the evolution strategy's default parameters (e.g., ``{'std_init': 0.1}``). If None, the strategy's default parameters are used unchanged.
6971
num_eval_steps : int, default=None
7072
Number of evaluation steps. If None, the evaluation runs until all episodes end.
7173
@@ -81,12 +83,16 @@ def __init__(
8183
population_size: int,
8284
obs_space_shape: Shape,
8385
act_space_shape: Shape,
84-
evo_strategy_params: dict = None,
86+
evo_strategy_kwargs: dict = None,
87+
evo_strategy_default_params: dict = None,
8588
num_eval_steps: int = None
8689
) -> None:
8790

88-
if evo_strategy_params is None:
89-
evo_strategy_params = {}
91+
if evo_strategy_kwargs is None:
92+
evo_strategy_kwargs = {}
93+
94+
if evo_strategy_default_params is None:
95+
evo_strategy_default_params = {}
9096

9197
self.obs_space_shape = obs_space_shape if jnp.ndim(obs_space_shape) > 0 else (obs_space_shape,)
9298
self.act_space_shape = act_space_shape if jnp.ndim(act_space_shape) > 0 else (act_space_shape,)
@@ -96,15 +102,16 @@ def __init__(
96102
variables = network.init(jax.random.key(0), x_dummy)
97103
num_params, params_format_fn = self.get_params_format_fn(variables)
98104

99-
evo_strategy_params['population_size'] = population_size
100-
evo_strategy_params['solution'] = jnp.zeros(num_params)
101-
evo_strategy = evo_strategy(**evo_strategy_params)
105+
evo_strategy_kwargs['population_size'] = population_size
106+
evo_strategy_kwargs['solution'] = jnp.zeros(num_params)
107+
evo_strategy = evo_strategy(**evo_strategy_kwargs)
102108

103109
self.init = jax.jit(partial(
104110
self.init,
105111
population_size=population_size,
106112
variables=variables,
107-
evo_strategy=evo_strategy
113+
evo_strategy=evo_strategy,
114+
evo_strategy_default_params=evo_strategy_default_params
108115
))
109116
self.update = jax.jit(partial(
110117
self.update,
@@ -168,7 +175,8 @@ def init(
168175
key: PRNGKey,
169176
population_size: int,
170177
variables: dict,
171-
evo_strategy: EvolutionaryAlgorithm
178+
evo_strategy: EvolutionaryAlgorithm,
179+
evo_strategy_default_params: dict
172180
) -> EvosaxState:
173181
r"""
174182
Initializes the evolution strategy state and the population. The fitness values, step counter,
@@ -184,6 +192,8 @@ def init(
184192
The initialized parameters of the agent network.
185193
evo_strategy : EvolutionaryAlgorithm
186194
Initialized evosax evolution strategy.
195+
evo_strategy_default_params : dict
196+
Values that override the default parameters of the evolution strategy before its state is initialized.
187197
188198
Returns
189199
-------
@@ -192,6 +202,8 @@ def init(
192202
"""
193203

194204
es_params = evo_strategy.default_params
205+
es_params = es_params.replace(**evo_strategy_default_params)
206+
195207
es_state = evo_strategy.init(key, variables, es_params)
196208
population, es_state = evo_strategy.ask(key, es_state, es_params)
197209

0 commit comments

Comments
 (0)