diff --git a/arlbench/autorl/autorl_env.py b/arlbench/autorl/autorl_env.py index a15523f07..4cb103a20 100644 --- a/arlbench/autorl/autorl_env.py +++ b/arlbench/autorl/autorl_env.py @@ -37,6 +37,8 @@ "env_name": "CartPole-v1", "env_kwargs": {}, "eval_env_kwargs": {}, + "env_params": {}, + "env_eval_params": {}, "n_envs": 10, "algorithm": "dqn", "cnn_policy": False, @@ -107,6 +109,7 @@ def __init__(self, config: dict | None = None) -> None: env_kwargs=self._config["env_kwargs"], cnn_policy=self._config["cnn_policy"], seed=self._seed, + env_params=self._config.get("env_params", None) ) self._eval_env = make_env( @@ -116,6 +119,7 @@ def __init__(self, config: dict | None = None) -> None: env_kwargs=self._config["eval_env_kwargs"], cnn_policy=self._config["cnn_policy"], seed=self._seed + 1, + env_params=self._config.get("env_eval_params", None) ) # Checkpointing diff --git a/arlbench/core/environments/gymnax_env.py b/arlbench/core/environments/gymnax_env.py index 120974de1..c64277815 100644 --- a/arlbench/core/environments/gymnax_env.py +++ b/arlbench/core/environments/gymnax_env.py @@ -6,6 +6,7 @@ import gymnax import jax +from dataclasses import replace from .autorl_env import Environment @@ -17,7 +18,7 @@ class GymnaxEnv(Environment): """A gymnax-based RL environment.""" def __init__( - self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None + self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None, env_params: dict[str, Any] | None = None ): """Creates a gymnax environment for JAX-based RL training. @@ -29,7 +30,8 @@ def __init__( """ if env_kwargs is None: env_kwargs = {} - env, env_params = gymnax.make(env_name, **env_kwargs) + env, og_env_params = gymnax.make(env_name, **env_kwargs) + env_params = replace(og_env_params, **(env_params or {})) super().__init__(env_name, env, n_envs) self.env_params = env_params diff --git a/arlbench/core/environments/make_env.py b/arlbench/core/environments/make_env.py index 1ffb4c7e6..00a645a1e 100644 --- a/arlbench/core/environments/make_env.py +++ b/arlbench/core/environments/make_env.py @@ -19,6 +19,7 @@ def make_env( n_envs: int = 1, seed: int = 0, env_kwargs: dict[str, Any] | None = None, + env_params: dict[str, Any] | None = None, ) -> Environment | Wrapper: """ARLBench equivalent to make_env in gymnasium/gymnax etc. Creates a JAX-compatible RL environment. @@ -50,7 +51,7 @@ def make_env( elif env_framework == "gymnax": from .gymnax_env import GymnaxEnv - env = GymnaxEnv(env_name, n_envs, env_kwargs=env_kwargs) + env = GymnaxEnv(env_name, n_envs, env_kwargs=env_kwargs, env_params=env_params) elif env_framework == "envpool": from .envpool_env import EnvpoolEnv diff --git a/examples/configs/base.yaml b/examples/configs/base.yaml index 3445ed89e..3770d1e23 100644 --- a/examples/configs/base.yaml +++ b/examples/configs/base.yaml @@ -20,6 +20,7 @@ autorl: env_framework: ${environment.framework} env_name: ${environment.name} env_kwargs: ${environment.kwargs} + env_params: ${environment.env_params} eval_env_kwargs: ${environment.eval_kwargs} n_envs: ${environment.n_envs} algorithm: ${algorithm} diff --git a/examples/configs/environment/cc_cartpole.yaml b/examples/configs/environment/cc_cartpole.yaml index d789c0dfb..d478d18bd 100644 --- a/examples/configs/environment/cc_cartpole.yaml +++ b/examples/configs/environment/cc_cartpole.yaml @@ -7,3 +7,7 @@ cnn_policy: False deterministic_eval: True jax_enable_x64: False n_envs: 8 + +env_params: + masspole: 0.9 + length: 0.7