3
3
import os
4
4
import pickle
5
5
6
- from dizoo .classic_control .cartpole .config .cartpole_offppo_config import cartpole_offppo_config , \
7
- cartpole_offppo_create_config # noqa
6
+ from dizoo .classic_control .cartpole .config .cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config , \
7
+ cartpole_ppo_offpolicy_create_config # noqa
8
8
from dizoo .classic_control .cartpole .config .cartpole_trex_offppo_config import cartpole_trex_offppo_config ,\
9
9
cartpole_trex_offppo_create_config
10
10
from dizoo .classic_control .cartpole .envs import CartPoleEnv
15
15
16
16
@pytest .fixture (scope = 'module' )
17
17
def setup_state_dict ():
18
- config = deepcopy (cartpole_offppo_config ), deepcopy (cartpole_offppo_create_config )
18
+ config = deepcopy (cartpole_ppo_offpolicy_config ), deepcopy (cartpole_ppo_offpolicy_create_config )
19
19
try :
20
20
policy = serial_pipeline (config , seed = 0 )
21
21
except Exception :
@@ -31,12 +31,14 @@ def setup_state_dict():
31
31
class TestApplication :
32
32
33
33
def test_eval (self , setup_state_dict ):
34
- cfg_for_stop_value = compile_config (cartpole_offppo_config , auto = True , create_cfg = cartpole_offppo_create_config )
34
+ cfg_for_stop_value = compile_config (
35
+ cartpole_ppo_offpolicy_config , auto = True , create_cfg = cartpole_ppo_offpolicy_create_config
36
+ )
35
37
stop_value = cfg_for_stop_value .env .stop_value
36
- config = deepcopy (cartpole_offppo_config ), deepcopy (cartpole_offppo_create_config )
38
+ config = deepcopy (cartpole_ppo_offpolicy_config ), deepcopy (cartpole_ppo_offpolicy_create_config )
37
39
episode_return = eval (config , seed = 0 , state_dict = setup_state_dict ['eval' ])
38
40
assert episode_return >= stop_value
39
- config = deepcopy (cartpole_offppo_config ), deepcopy (cartpole_offppo_create_config )
41
+ config = deepcopy (cartpole_ppo_offpolicy_config ), deepcopy (cartpole_ppo_offpolicy_create_config )
40
42
episode_return = eval (
41
43
config ,
42
44
seed = 0 ,
@@ -46,7 +48,7 @@ def test_eval(self, setup_state_dict):
46
48
assert episode_return >= stop_value
47
49
48
50
def test_collect_demo_data (self , setup_state_dict ):
49
- config = deepcopy (cartpole_offppo_config ), deepcopy (cartpole_offppo_create_config )
51
+ config = deepcopy (cartpole_ppo_offpolicy_config ), deepcopy (cartpole_ppo_offpolicy_create_config )
50
52
collect_count = 16
51
53
expert_data_path = './expert.data'
52
54
collect_demo_data (
0 commit comments