Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions arlbench/autorl/autorl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
"env_name": "CartPole-v1",
"env_kwargs": {},
"eval_env_kwargs": {},
"env_params": {},
"env_eval_params": {},
"n_envs": 10,
"algorithm": "dqn",
"cnn_policy": False,
Expand Down Expand Up @@ -107,6 +109,7 @@ def __init__(self, config: dict | None = None) -> None:
env_kwargs=self._config["env_kwargs"],
cnn_policy=self._config["cnn_policy"],
seed=self._seed,
env_params=self._config.get("env_params", None)
)

self._eval_env = make_env(
Expand All @@ -116,6 +119,7 @@ def __init__(self, config: dict | None = None) -> None:
env_kwargs=self._config["eval_env_kwargs"],
cnn_policy=self._config["cnn_policy"],
seed=self._seed + 1,
env_params=self._config.get("env_eval_params", None)
)

# Checkpointing
Expand Down
6 changes: 4 additions & 2 deletions arlbench/core/environments/gymnax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import gymnax
import jax
from dataclasses import replace

from .autorl_env import Environment

Expand All @@ -17,7 +18,7 @@ class GymnaxEnv(Environment):
"""A gymnax-based RL environment."""

def __init__(
self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None
self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None, env_params: dict[str, Any] | None = None
):
"""Creates a gymnax environment for JAX-based RL training.

Expand All @@ -29,7 +30,8 @@ def __init__(
"""
if env_kwargs is None:
env_kwargs = {}
env, env_params = gymnax.make(env_name, **env_kwargs)
env, og_env_params = gymnax.make(env_name, **env_kwargs)
env_params = replace(og_env_params, **(env_params or {}))
super().__init__(env_name, env, n_envs)

self.env_params = env_params
Expand Down
3 changes: 2 additions & 1 deletion arlbench/core/environments/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def make_env(
n_envs: int = 1,
seed: int = 0,
env_kwargs: dict[str, Any] | None = None,
env_params: dict[str, Any] | None = None,
) -> Environment | Wrapper:
"""ARLBench equivalent to make_env in gymnasium/gymnax etc.
Creates a JAX-compatible RL environment.
Expand Down Expand Up @@ -50,7 +51,7 @@ def make_env(
elif env_framework == "gymnax":
from .gymnax_env import GymnaxEnv

env = GymnaxEnv(env_name, n_envs, env_kwargs=env_kwargs)
env = GymnaxEnv(env_name, n_envs, env_kwargs=env_kwargs, env_params=env_params)
elif env_framework == "envpool":
from .envpool_env import EnvpoolEnv

Expand Down
1 change: 1 addition & 0 deletions examples/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ autorl:
env_framework: ${environment.framework}
env_name: ${environment.name}
env_kwargs: ${environment.kwargs}
env_params: ${environment.env_params}
eval_env_kwargs: ${environment.eval_kwargs}
n_envs: ${environment.n_envs}
algorithm: ${algorithm}
Expand Down
4 changes: 4 additions & 0 deletions examples/configs/environment/cc_cartpole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ cnn_policy: False
deterministic_eval: True
jax_enable_x64: False
n_envs: 8

env_params:
masspole: 0.9
length: 0.7
Loading