Skip to content

Commit 2ea4b34

Browse files
committed
reorganize tests
1 parent 0457845 commit 2ea4b34

File tree

4 files changed

+79
-36
lines changed

4 files changed

+79
-36
lines changed

test/test_atari.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import unittest
2+
import pytest
3+
from os import environ
4+
from rl.experiment import run
5+
from . import conftest
6+
import pandas as pd
7+
8+
9+
class AtariTest(unittest.TestCase):
10+
11+
@unittest.skipIf(environ.get('CI'),
12+
"Delay CI test until dev stable")
13+
@classmethod
14+
def test_breakout_dqn(cls):
15+
data_df = run('breakout_dqn')
16+
assert isinstance(data_df, pd.DataFrame)
17+
18+
@unittest.skipIf(environ.get('CI'),
19+
"Delay CI test until dev stable")
20+
@classmethod
21+
def test_breakout_double_dqn(cls):
22+
data_df = run('breakout_double_dqn')
23+
assert isinstance(data_df, pd.DataFrame)

test/test_box2d.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import unittest
2+
import pytest
3+
from os import environ
4+
from rl.experiment import run
5+
from . import conftest
6+
import pandas as pd
7+
8+
9+
class Box2DTest(unittest.TestCase):
10+
11+
@classmethod
12+
def test_lunar_dqn(cls):
13+
data_df = run('lunar_dqn')
14+
assert isinstance(data_df, pd.DataFrame)
15+
16+
@classmethod
17+
def test_lunar_double_dqn(cls):
18+
data_df = run('lunar_double_dqn')
19+
assert isinstance(data_df, pd.DataFrame)
20+
21+
@classmethod
22+
def test_walker_ddpg_linearnoise(cls):
23+
data_df = run('walker_ddpg_linearnoise')
24+
assert isinstance(data_df, pd.DataFrame)
Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,58 +6,64 @@
66
import pandas as pd
77

88

9-
class AdvancedTest(unittest.TestCase):
9+
class ClassicTest(unittest.TestCase):
1010

1111
@classmethod
12-
def test_sarsa(cls):
13-
data_df = run('sarsa')
12+
def test_quickstart_dqn(cls):
13+
data_df = run('quickstart_dqn')
1414
assert isinstance(data_df, pd.DataFrame)
1515

1616
@classmethod
17-
def test_exp_sarsa(cls):
18-
data_df = run('exp_sarsa')
17+
def test_dqn_epsilon(cls):
18+
data_df = run('dqn_epsilon')
1919
assert isinstance(data_df, pd.DataFrame)
2020

2121
@classmethod
22-
def test_offpol_sarsa(cls):
23-
data_df = run('offpol_sarsa')
22+
def test_dqn(cls):
23+
data_df = run('dqn')
2424
assert isinstance(data_df, pd.DataFrame)
2525

2626
@classmethod
27-
def test_acrobot(cls):
28-
data_df = run('acrobot')
27+
def test_double_dqn(cls):
28+
data_df = run('double_dqn')
2929
assert isinstance(data_df, pd.DataFrame)
3030

3131
@classmethod
32-
def test_mountain_dqn(cls):
33-
data_df = run('mountain_dqn')
32+
def test_sarsa(cls):
33+
data_df = run('sarsa')
34+
assert isinstance(data_df, pd.DataFrame)
35+
36+
@classmethod
37+
def test_exp_sarsa(cls):
38+
data_df = run('exp_sarsa')
3439
assert isinstance(data_df, pd.DataFrame)
3540

3641
@classmethod
37-
def test_lunar_dqn(cls):
38-
data_df = run('lunar_dqn')
42+
def test_offpol_sarsa(cls):
43+
data_df = run('offpol_sarsa')
3944
assert isinstance(data_df, pd.DataFrame)
4045

41-
@unittest.skipIf(environ.get('CI'),
42-
"Delay CI test until dev stable")
4346
@classmethod
44-
def test_breakout_dqn(cls):
45-
data_df = run('breakout_dqn')
47+
def test_cartpole_ac_argmax(cls):
48+
data_df = run('cartpole_ac_argmax')
4649
assert isinstance(data_df, pd.DataFrame)
4750

48-
@unittest.skipIf(environ.get('CI'),
49-
"Delay CI test until dev stable")
5051
@classmethod
51-
def test_breakout_double_dqn(cls):
52-
data_df = run('breakout_double_dqn')
52+
def test_dqn_v1(cls):
53+
data_df = run('dqn_v1')
5354
assert isinstance(data_df, pd.DataFrame)
5455

5556
@classmethod
56-
def test_cartpole_ac_argmax(cls):
57-
data_df = run('cartpole_ac_argmax')
57+
def test_acrobot(cls):
58+
data_df = run('acrobot')
5859
assert isinstance(data_df, pd.DataFrame)
5960

6061
@classmethod
61-
def test_pendulum_ddpg(cls):
62-
data_df = run('pendulum_ddpg')
62+
def test_pendulum_ddpg_linearnoise(cls):
63+
data_df = run('pendulum_ddpg_linearnoise')
64+
assert isinstance(data_df, pd.DataFrame)
65+
66+
@classmethod
67+
def test_mountain_dqn(cls):
68+
data_df = run('mountain_dqn')
6369
assert isinstance(data_df, pd.DataFrame)
Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77

88

9-
class BasicTest(unittest.TestCase):
9+
class DevTest(unittest.TestCase):
1010

1111
@classmethod
1212
def test_clean_import(cls):
@@ -46,13 +46,3 @@ def test_dqn_pass(cls):
4646
# def test_dqn_random_search(cls):
4747
# data_df = run('test_dqn_random_search', param_selection=True)
4848
# assert isinstance(data_df, pd.DataFrame)
49-
50-
@classmethod
51-
def test_dqn(cls):
52-
data_df = run('dqn')
53-
assert isinstance(data_df, pd.DataFrame)
54-
55-
@classmethod
56-
def test_double_dqn(cls):
57-
data_df = run('double_dqn')
58-
assert isinstance(data_df, pd.DataFrame)

0 commit comments

Comments
 (0)