Skip to content

Commit 21d8578

Browse files
committed
add more tests
1 parent 2ea4b34 commit 21d8578

File tree

4 files changed

+37
-4
lines changed

4 files changed

+37
-4
lines changed

rl/spec/classic_experiment_specs.json

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,33 @@
7676
]
7777
}
7878
},
79+
"dqn_per": {
80+
"problem": "CartPole-v0",
81+
"Agent": "DQN",
82+
"HyperOptimizer": "GridSearch",
83+
"Memory": "PrioritizedExperienceReplay",
84+
"Optimizer": "AdamOptimizer",
85+
"Policy": "BoltzmannPolicy",
86+
"PreProcessor": "NoPreProcessor",
87+
"param": {
88+
"lr": 0.02,
89+
"gamma": 0.99,
90+
"hidden_layers": [64],
91+
"hidden_layers_activation": "sigmoid",
92+
"exploration_anneal_episodes": 10
93+
},
94+
"param_range": {
95+
"lr": [0.001, 0.005, 0.01, 0.02],
96+
"gamma": [0.95, 0.97, 0.99, 0.999],
97+
"hidden_layers": [
98+
[16],
99+
[32],
100+
[64],
101+
[16, 8],
102+
[32, 16]
103+
]
104+
}
105+
},
79106
"rand_dqn": {
80107
"problem": "CartPole-v0",
81108
"Agent": "DQN",

test/test_atari.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@
88

99
class AtariTest(unittest.TestCase):
1010

11-
@unittest.skipIf(environ.get('CI'),
12-
"Delay CI test until dev stable")
1311
@classmethod
1412
def test_breakout_dqn(cls):
1513
data_df = run('breakout_dqn')
1614
assert isinstance(data_df, pd.DataFrame)
1715

18-
@unittest.skipIf(environ.get('CI'),
19-
"Delay CI test until dev stable")
2016
@classmethod
2117
def test_breakout_double_dqn(cls):
2218
data_df = run('breakout_double_dqn')

test/test_box2d.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ def test_lunar_double_dqn(cls):
1818
data_df = run('lunar_double_dqn')
1919
assert isinstance(data_df, pd.DataFrame)
2020

21+
@classmethod
22+
def test_lunar_freeze(cls):
23+
data_df = run('lunar_freeze')
24+
assert isinstance(data_df, pd.DataFrame)
25+
2126
@classmethod
2227
def test_walker_ddpg_linearnoise(cls):
2328
data_df = run('walker_ddpg_linearnoise')

test/test_classic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ def test_dqn(cls):
2323
data_df = run('dqn')
2424
assert isinstance(data_df, pd.DataFrame)
2525

26+
@classmethod
27+
def test_dqn_per(cls):
28+
data_df = run('dqn_per')
29+
assert isinstance(data_df, pd.DataFrame)
30+
2631
@classmethod
2732
def test_double_dqn(cls):
2833
data_df = run('double_dqn')

0 commit comments

Comments
 (0)