@@ -64,8 +64,10 @@ class Evosax(BaseAgent):
6464 Shape of the observation space.
6565 act_space_shape : Shape, default=(1,)
6666 Shape of the action space. For discrete action spaces, use (1,).
67- evo_strategy_params : dict, default=None
68- Parameters for the evolution strategy.
67+ evo_strategy_kwargs : dict, default=None
68+ Keyword arguments passed to the evolution strategy constructor. The ``population_size`` and ``solution`` entries are set automatically and overwrite any values supplied here.
69+ evo_strategy_default_params : dict, default=None
70+ Overrides applied on top of the evolution strategy's ``default_params``. If None, the default parameters are used unchanged.
6971 num_eval_steps : int, default=None
7072 Number of evaluation steps. If None, the evaluation runs until all episodes end.
7173
@@ -81,12 +83,16 @@ def __init__(
8183 population_size : int ,
8284 obs_space_shape : Shape ,
8385 act_space_shape : Shape ,
84- evo_strategy_params : dict = None ,
86+ evo_strategy_kwargs : dict = None ,
87+ evo_strategy_default_params : dict = None ,
8588 num_eval_steps : int = None
8689 ) -> None :
8790
88- if evo_strategy_params is None :
89- evo_strategy_params = {}
91+ if evo_strategy_kwargs is None :
92+ evo_strategy_kwargs = {}
93+
94+ if evo_strategy_default_params is None :
95+ evo_strategy_default_params = {}
9096
9197 self .obs_space_shape = obs_space_shape if jnp .ndim (obs_space_shape ) > 0 else (obs_space_shape ,)
9298 self .act_space_shape = act_space_shape if jnp .ndim (act_space_shape ) > 0 else (act_space_shape ,)
@@ -96,15 +102,16 @@ def __init__(
96102 variables = network .init (jax .random .key (0 ), x_dummy )
97103 num_params , params_format_fn = self .get_params_format_fn (variables )
98104
99- evo_strategy_params ['population_size' ] = population_size
100- evo_strategy_params ['solution' ] = jnp .zeros (num_params )
101- evo_strategy = evo_strategy (** evo_strategy_params )
105+ evo_strategy_kwargs ['population_size' ] = population_size
106+ evo_strategy_kwargs ['solution' ] = jnp .zeros (num_params )
107+ evo_strategy = evo_strategy (** evo_strategy_kwargs )
102108
103109 self .init = jax .jit (partial (
104110 self .init ,
105111 population_size = population_size ,
106112 variables = variables ,
107- evo_strategy = evo_strategy
113+ evo_strategy = evo_strategy ,
114+ evo_strategy_default_params = evo_strategy_default_params
108115 ))
109116 self .update = jax .jit (partial (
110117 self .update ,
@@ -168,7 +175,8 @@ def init(
168175 key : PRNGKey ,
169176 population_size : int ,
170177 variables : dict ,
171- evo_strategy : EvolutionaryAlgorithm
178+ evo_strategy : EvolutionaryAlgorithm ,
179+ evo_strategy_default_params : dict
172180 ) -> EvosaxState :
173181 r"""
174182 Initializes the evolution strategy state and the population. The fitness values, step counter,
@@ -184,6 +192,8 @@ def init(
184192 The initialized parameters of the agent network.
185193 evo_strategy : EvolutionaryAlgorithm
186194 Initialized evosax evolution strategy.
195+ evo_strategy_default_params : dict
196+ Overrides applied on top of the evolution strategy's ``default_params`` before the state is initialized.
187197
188198 Returns
189199 -------
@@ -192,6 +202,8 @@ def init(
192202 """
193203
194204 es_params = evo_strategy .default_params
205+ es_params = es_params .replace (** evo_strategy_default_params )
206+
195207 es_state = evo_strategy .init (key , variables , es_params )
196208 population , es_state = evo_strategy .ask (key , es_state , es_params )
197209
0 commit comments