Skip to content

Commit fb3de72

Browse files
authored
Merge pull request #42 from automl/fix-init-on-step
This PR fixes issues #40 and #41 and adds the initial test cases that were previously missing.
2 parents bd4c15d + 296d14d commit fb3de72

File tree

16 files changed

+1005
-4
lines changed

16 files changed

+1005
-4
lines changed

arlbench/autorl/autorl_env.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
import warnings
55
from collections.abc import Callable
6-
from typing import Any
6+
from typing import Any, Dict
77

88
import gymnasium
99
import jax
1010
import numpy as np
1111
import pandas as pd
12+
from omegaconf import OmegaConf
1213
from ConfigSpace import Configuration, ConfigurationSpace
1314

1415
from arlbench.core.algorithms import (
@@ -265,6 +266,37 @@ def _make_algorithm(self) -> Algorithm:
265266
cnn_policy=self._config["cnn_policy"],
266267
deterministic_eval=self._config["deterministic_eval"],
267268
)
269+
270+
def get_algorithm_init_kwargs(self, init_rng) -> Dict:
    """Collect the keyword arguments used to re-initialise the current algorithm.

    The returned dictionary carries ``init_rng`` under ``"rng"`` together with
    the network parameters and optimizer states read from the stored
    algorithm state (plus the buffer state for DQN and SAC), so the result
    can be passed as ``self._algorithm.init(**kwargs)``.

    Args:
        init_rng: Random key forwarded unchanged as the ``"rng"`` entry.

    Returns:
        Dict: Dictionary of algorithm initialization parameters.

    Raises:
        ValueError: If no algorithm state has been stored yet, or if the
            current algorithm is not a PPO, DQN or SAC instance.
    """
    algo_state = self._algorithm_state
    # Fail with a clear message instead of an AttributeError on ``None``.
    if algo_state is None:
        raise ValueError(
            "No algorithm state available: the algorithm must be initialised "
            "before its init kwargs can be collected."
        )

    if isinstance(self._algorithm, PPO):
        train_state = algo_state.runner_state.train_state
        return {
            "rng": init_rng,
            "network_params": train_state.params,
            "opt_state": train_state.opt_state,
        }
    elif isinstance(self._algorithm, DQN):
        train_state = algo_state.runner_state.train_state
        return {
            "rng": init_rng,
            "buffer_state": algo_state.buffer_state,
            "network_params": train_state.params,
            "target_params": train_state.target_params,
            "opt_state": train_state.opt_state,
        }
    elif isinstance(self._algorithm, SAC):
        runner_state = algo_state.runner_state
        actor = runner_state.actor_train_state
        critic = runner_state.critic_train_state
        alpha = runner_state.alpha_train_state
        return {
            "rng": init_rng,
            "buffer_state": algo_state.buffer_state,
            "actor_network_params": actor.params,
            "critic_network_params": critic.params,
            "critic_target_params": critic.target_params,
            "alpha_network_params": alpha.params,
            "actor_opt_state": actor.opt_state,
            "critic_opt_state": critic.opt_state,
            "alpha_opt_state": alpha.opt_state,
        }
    else:
        raise ValueError(f"Unsupported algorithm: {self._algorithm.name}")
268300

269301
def step(
270302
self,
@@ -304,7 +336,9 @@ def step(
304336

305337
# Apply changes to current hyperparameter configuration and reinstantiate algorithm
306338
if isinstance(action, dict):
307-
action = Configuration(self.config_space, action)
339+
action_config = dict(self._hpo_config)
340+
action_config.update(action)
341+
action = Configuration(self.config_space, action_config)
308342
self._hpo_config = action
309343

310344
seed = seed if seed else self._seed
@@ -325,6 +359,10 @@ def step(
325359
elif self._algorithm_state is None:
326360
init_rng = jax.random.key(seed)
327361
self._algorithm_state = self._algorithm.init(init_rng)
362+
else:
363+
init_rng = jax.random.key(seed)
364+
init_kwargs = self.get_algorithm_init_kwargs(init_rng)
365+
self._algorithm_state = self._algorithm.init(**init_kwargs)
328366

329367
# Training kwargs
330368
train_kw_args = {

arlbench/core/algorithms/sac/sac.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,8 @@ def init(
394394
_action = self.env.sample_actions(dummy_rng)
395395

396396
# for x64 enabled runs we have to explicitly cast the dummy action
397-
_action = jnp.array(_action, dtype=jnp.float64)
397+
dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
398+
_action = jnp.array(_action, dtype=dtype)
398399

399400
_, (_obs, _reward, _done, _) = self.env.step(env_state, _action, dummy_rng)
400401

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dependencies = [
2222
"coverage==7.4.4",
2323
"chex==0.1.86",
2424
"xminigrid==0.8.0",
25+
"gymnasium==1.2.0",
2526
"ruff",
2627
"hydra-core",
2728
"hydra-submitit-launcher",
@@ -57,7 +58,7 @@ tooling = ["commitizen", "pre-commit", "ruff"]
5758
test = ["pytest", "pytest-coverage", "pytest-cases", "ARLBench[examples]"]
5859
examples = ["hypersweeper"]
5960
doc = [
60-
"automl_sphinx_theme", "gymnasium==0.29.1"
61+
"automl_sphinx_theme"
6162
]
6263
envpool = ["envpool==0.8.4"]
6364

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit test package for arlbench."""

tests/autorl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit test package for autorl subpackage."""

tests/autorl/test_autorl_env.py

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
from arlbench import AutoRLEnv
5+
from arlbench.core.algorithms import DQN
6+
7+
8+
def test_autorl_env_dqn_default_obs():
    """Check the default observation of a DQN run on gymnax CartPole-v1.

    With no state features configured, the observation must be empty after
    ``reset()`` and contain only the ``"steps"`` entry after one ``step()``.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="dqn",
        cnn_policy=False,
        n_total_timesteps=1e6,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=[],
        n_steps=10,
    )

    env = AutoRLEnv(config=settings)
    reset_obs, _reset_info = env.reset()
    assert len(reset_obs.keys()) == 0

    # A randomly sampled configuration serves as the HPO action.
    sampled_action = env.config_space.sample_configuration()
    step_result = env.step(sampled_action)
    obs, objectives, _, trunc, _ = step_result

    assert len(obs.keys()) == 1
    assert obs["steps"].shape == (2,)
    assert trunc is False
    assert objectives["reward_mean"] > 0
34+
35+
36+
def test_autorl_env_dqn_grad_obs():
    """Check that DQN exposes ``grad_info`` next to ``steps`` in the observation.

    Uses the default configuration as the action and expects both observation
    entries to have shape ``(2,)`` after a single step.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="dqn",
        cnn_policy=False,
        n_total_timesteps=1e5,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=["grad_info"],
        n_steps=10,
    )

    env = AutoRLEnv(config=settings)
    reset_obs, _reset_info = env.reset()
    assert len(reset_obs.keys()) == 0

    default_action = env.config_space.get_default_configuration()
    obs, objectives, _, trunc, _ = env.step(default_action)

    assert len(obs.keys()) == 2
    # Both observation entries are checked in the original order.
    for feature in ("steps", "grad_info"):
        assert obs[feature].shape == (2,)
    assert trunc is False
    assert objectives["reward_mean"] > 0
63+
64+
65+
def test_autorl_env_ppo_grad_obs():
    """Check that PPO exposes ``grad_info`` next to ``steps`` in the observation.

    Uses the default configuration as the action and expects both observation
    entries to have shape ``(2,)`` after a single step.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="ppo",
        cnn_policy=False,
        n_total_timesteps=1e5,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=["grad_info"],
        n_steps=10,
    )

    env = AutoRLEnv(config=settings)
    reset_obs, _reset_info = env.reset()
    assert len(reset_obs.keys()) == 0

    default_action = env.config_space.get_default_configuration()
    obs, objectives, _, trunc, _ = env.step(default_action)

    assert len(obs.keys()) == 2
    # Both observation entries are checked in the original order.
    for feature in ("steps", "grad_info"):
        assert obs[feature].shape == (2,)
    assert trunc is False
    assert objectives["reward_mean"] > 0
92+
93+
94+
def test_autorl_env_sac_grad_obs():
    """Check that SAC exposes ``grad_info`` next to ``steps`` in the observation.

    Runs on gymnax Pendulum-v1 with the default configuration; the mean reward
    only has to exceed ``-2000``.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="Pendulum-v1",
        n_envs=10,
        algorithm="sac",
        cnn_policy=False,
        n_total_timesteps=5e4,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=["grad_info"],
        n_steps=10,
    )

    env = AutoRLEnv(config=settings)
    reset_obs, _reset_info = env.reset()
    assert len(reset_obs.keys()) == 0

    default_action = env.config_space.get_default_configuration()
    obs, objectives, _, trunc, _ = env.step(default_action)

    assert len(obs.keys()) == 2
    # Both observation entries are checked in the original order.
    for feature in ("steps", "grad_info"):
        assert obs[feature].shape == (2,)
    assert trunc is False
    assert objectives["reward_mean"] > -2000
121+
122+
123+
def test_autorl_env_dqn_per_switch():
    """Toggle ``buffer_prio_sampling`` between consecutive DQN steps.

    Three steps alternate the flag, then the environment is reset and three
    more alternating steps follow. Each step must beat its reward threshold.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="dqn",
        cnn_policy=False,
        n_total_timesteps=1e6,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=[],
        n_steps=10,
    )

    env = AutoRLEnv(settings)
    _, _ = env.reset()
    action = env.config_space.get_default_configuration()

    # (buffer_prio_sampling value, minimum mean reward) per step.
    before_reset = ((True, 100), (False, 150), (True, 200))
    after_reset = ((False, 200), (True, 200), (False, 200))

    for use_prio, min_reward in before_reset:
        action["buffer_prio_sampling"] = use_prio
        _, objectives, _, _, _ = env.step(action)
        assert objectives["reward_mean"] > min_reward

    _, _ = env.reset()

    for use_prio, min_reward in after_reset:
        action["buffer_prio_sampling"] = use_prio
        _, objectives, _, _, _ = env.step(action)
        assert objectives["reward_mean"] > min_reward
167+
168+
169+
def test_autorl_env_dqn_dac():
    """Run three episodes of dynamic configuration with ``n_steps = 3``.

    Each episode steps with freshly sampled configurations until the
    environment reports truncation, which must happen after exactly 3 steps.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="dqn",
        cnn_policy=False,
        n_total_timesteps=1e6,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=[],
        n_steps=3,
    )

    env = AutoRLEnv(settings)
    # perform 3 HPO steps
    for _episode in range(3):
        _, _ = env.reset()
        n_taken = 0
        trunc = False
        while not trunc:
            sampled_action = env.config_space.sample_configuration()

            obs, objectives, _, trunc, _ = env.step(sampled_action)
            n_taken += 1
            assert len(obs.keys()) == 1
            assert obs["steps"].shape == (2,)
            assert objectives["reward_mean"] > 0
        assert trunc is True
        assert n_taken == 3
201+
202+
203+
def test_autorl_env_dqn_hpo():
    """Run DQN in the classic (static) HPO setting with a single step.

    With ``n_steps = 1`` the first ``step()`` must already report truncation.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="dqn",
        cnn_policy=False,
        n_total_timesteps=1e5,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=[],
        n_steps=1,  # Classic (static) HPO
    )

    env = AutoRLEnv(settings)

    _, _ = env.reset()
    sampled_action = env.config_space.sample_configuration()
    step_result = env.step(sampled_action)
    obs, objectives, _, trunc, _ = step_result

    assert len(obs.keys()) == 1
    assert obs["steps"].shape == (2,)
    assert objectives["reward_mean"] > 0
    assert trunc is True
228+
229+
230+
def test_autorl_env_step_before_reset():
    """Check that ``step()`` on a fresh environment raises before ``reset()``.

    The environment is constructed but never reset; the first ``step()`` call
    must raise a ``ValueError`` whose message contains
    ``"Called step() before reset()"``.
    """
    config = {
        "seed": 42,
        "env_framework": "gymnax",
        "env_name": "CartPole-v1",
        "n_envs": 10,
        "algorithm": "dqn",
        "cnn_policy": False,
        "n_total_timesteps": 1e6,
        "n_eval_steps": 10,
        "checkpoint": [],
        "objectives": ["reward_mean"],
        "state_features": [],
        "n_steps": 1,  # Classic HPO
    }

    env = AutoRLEnv(config)

    # Build the action outside the ``raises`` context so that only the
    # ``step()`` call itself can satisfy the expected ValueError.
    action = dict(DQN.get_hpo_config_space().sample_configuration())

    with pytest.raises(ValueError) as excinfo:
        env.step(action)

    assert "Called step() before reset()" in str(excinfo.value)
253+
254+
255+
def test_autorl_env_forbidden_step():
    """Check that a second ``step()`` without ``reset()`` is rejected.

    With ``n_steps = 1`` the first step is taken normally; repeating the step
    must raise a ``ValueError`` whose message contains
    ``"Called step() before reset()"``.
    """
    settings = dict(
        seed=42,
        env_framework="gymnax",
        env_name="CartPole-v1",
        n_envs=10,
        algorithm="dqn",
        cnn_policy=False,
        n_total_timesteps=1e5,
        n_eval_steps=10,
        checkpoint=[],
        objectives=["reward_mean"],
        state_features=[],
        n_steps=1,  # Classic HPO
    )

    env = AutoRLEnv(settings)
    env.reset()
    sampled_action = env.config_space.sample_configuration()

    # First step is the only one allowed in this setting.
    env.step(sampled_action)

    with pytest.raises(ValueError) as raised:
        env.step(sampled_action)

    assert "Called step() before reset()" in str(raised.value)
280+
281+
282+
if __name__ == "__main__":
283+
test_autorl_env_dqn_per_switch()
284+

0 commit comments

Comments
 (0)