Skip to content

Commit 3a9f213

Browse files
authored
polish(nyz): polish example demos (#568)
* polish(nyz): polish example demos * fix(nyz): fix unittest bugs * fix(nyz): fix trex unittest bugs
1 parent 886285d commit 3a9f213

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+100
-96
lines changed

ding/config/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def compile_config(
341341
create_cfg: dict = None,
342342
save_cfg: bool = True,
343343
save_path: str = 'total_config.py',
344+
renew_dir: bool = True,
344345
) -> EasyDict:
345346
"""
346347
Overview:
@@ -361,6 +362,7 @@ def compile_config(
361362
- create_cfg (:obj:`dict`): Input create config dict
362363
- save_cfg (:obj:`bool`): Save config or not
363364
- save_path (:obj:`str`): Path of saving file
365+
- renew_dir (:obj:`bool`): Whether to new a directory for saving config.
364366
Returns:
365367
- cfg (:obj:`EasyDict`): Config after compiling
366368
"""
@@ -460,7 +462,7 @@ def compile_config(
460462
if 'exp_name' not in cfg:
461463
cfg.exp_name = 'default_experiment'
462464
if save_cfg:
463-
if os.path.exists(cfg.exp_name):
465+
if os.path.exists(cfg.exp_name) and renew_dir:
464466
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
465467
try:
466468
os.makedirs(cfg.exp_name)

ding/entry/serial_entry_preference_based_irl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def serial_pipeline_preference_based_irl(
4747
create_cfg.policy.type = create_cfg.policy.type + '_command'
4848
create_cfg.reward_model = dict(type=cfg.reward_model.type)
4949
env_fn = None if env_setting is None else env_setting[0]
50-
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
50+
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False)
5151
cfg_bak = copy.deepcopy(cfg)
5252
# Create main components: env, policy
5353
if env_setting is None:

ding/entry/serial_entry_preference_based_irl_onpolicy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def serial_pipeline_preference_based_irl_onpolicy(
4646
create_cfg.policy.type = create_cfg.policy.type + '_command'
4747
create_cfg.reward_model = dict(type=cfg.reward_model.type)
4848
env_fn = None if env_setting is None else env_setting[0]
49-
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
49+
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False)
5050
# Create main components: env, policy
5151
if env_setting is None:
5252
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)

ding/entry/tests/test_application_entry.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import os
44
import pickle
55

6-
from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, \
7-
cartpole_offppo_create_config # noqa
6+
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
7+
cartpole_ppo_offpolicy_create_config # noqa
88
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
99
cartpole_trex_offppo_create_config
1010
from dizoo.classic_control.cartpole.envs import CartPoleEnv
@@ -15,7 +15,7 @@
1515

1616
@pytest.fixture(scope='module')
1717
def setup_state_dict():
18-
config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
18+
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
1919
try:
2020
policy = serial_pipeline(config, seed=0)
2121
except Exception:
@@ -31,12 +31,14 @@ def setup_state_dict():
3131
class TestApplication:
3232

3333
def test_eval(self, setup_state_dict):
34-
cfg_for_stop_value = compile_config(cartpole_offppo_config, auto=True, create_cfg=cartpole_offppo_create_config)
34+
cfg_for_stop_value = compile_config(
35+
cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config
36+
)
3537
stop_value = cfg_for_stop_value.env.stop_value
36-
config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
38+
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
3739
episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
3840
assert episode_return >= stop_value
39-
config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
41+
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
4042
episode_return = eval(
4143
config,
4244
seed=0,
@@ -46,7 +48,7 @@ def test_eval(self, setup_state_dict):
4648
assert episode_return >= stop_value
4749

4850
def test_collect_demo_data(self, setup_state_dict):
49-
config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
51+
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
5052
collect_count = 16
5153
expert_data_path = './expert.data'
5254
collect_demo_data(

ding/entry/tests/test_application_entry_trex_collect_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
1010
cartpole_trex_offppo_create_config
11-
from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config,\
12-
cartpole_offppo_create_config
11+
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
12+
cartpole_ppo_offpolicy_create_config
1313
from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex, trex_collecting_data
1414
from ding.entry import serial_pipeline
1515

@@ -18,7 +18,7 @@
1818
def test_collect_episodic_demo_data_for_trex():
1919
exp_name = "test_collect_episodic_demo_data_for_trex_expert"
2020
expert_policy_state_dict_path = os.path.join(exp_name, 'expert_policy.pth.tar')
21-
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
21+
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
2222
config[0].exp_name = exp_name
2323
expert_policy = serial_pipeline(config, seed=0)
2424
torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
@@ -41,7 +41,7 @@ def test_collect_episodic_demo_data_for_trex():
4141
@pytest.mark.unittest
4242
def test_trex_collecting_data():
4343
expert_policy_dir = 'test_trex_collecting_data_expert'
44-
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
44+
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
4545
config[0].exp_name = expert_policy_dir
4646
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
4747
serial_pipeline(config, seed=0)

ding/entry/tests/test_serial_entry.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from dizoo.classic_control.cartpole.config.cartpole_dqn_stdim_config import cartpole_dqn_stdim_config, \
1010
cartpole_dqn_stdim_create_config
1111
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
12-
from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, \
13-
cartpole_offppo_create_config
12+
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
13+
cartpole_ppo_offpolicy_create_config
1414
from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config # noqa
1515
from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config # noqa
1616
from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config # noqa
@@ -209,7 +209,7 @@ def test_qrdqn():
209209
@pytest.mark.platformtest
210210
@pytest.mark.unittest
211211
def test_ppo():
212-
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
212+
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
213213
config[0].policy.learn.update_per_collect = 1
214214
config[0].exp_name = 'ppo_offpolicy_unittest'
215215
try:
@@ -221,7 +221,7 @@ def test_ppo():
221221
@pytest.mark.platformtest
222222
@pytest.mark.unittest
223223
def test_ppo_nstep_return():
224-
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
224+
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
225225
config[0].policy.learn.update_per_collect = 1
226226
config[0].policy.nstep_return = True
227227
try:

ding/entry/tests/test_serial_entry_bc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ding.utils import POLICY_REGISTRY
1515
from ding.utils.data import default_collate, default_decollate
1616
from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config, \
17-
cartpole_offppo_config, cartpole_offppo_create_config
17+
cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
1818
from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config
1919

2020

@@ -53,22 +53,22 @@ def _monitor_vars_learn(self) -> list:
5353
@pytest.mark.unittest
5454
def test_serial_pipeline_bc_ppo():
5555
# train expert policy
56-
train_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
56+
train_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
5757
train_config[0].exp_name = 'test_serial_pipeline_bc_ppo'
5858
expert_policy = serial_pipeline(train_config, seed=0)
5959

6060
# collect expert demo data
6161
collect_count = 10000
6262
expert_data_path = 'expert_data_ppo_bc.pkl'
6363
state_dict = expert_policy.collect_mode.state_dict()
64-
collect_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
64+
collect_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
6565
collect_config[0].exp_name = 'test_serial_pipeline_bc_ppo_collect'
6666
collect_demo_data(
6767
collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
6868
)
6969

7070
# il training 1
71-
il_config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
71+
il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
7272
il_config[0].policy.eval.evaluator.multi_gpu = False
7373
il_config[0].policy.learn.train_epoch = 20
7474
il_config[1].policy.type = 'ppo_bc'

ding/entry/tests/test_serial_entry_preference_based_irl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from ding.entry import serial_pipeline_preference_based_irl
1010
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
1111
cartpole_trex_offppo_create_config
12-
from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config,\
13-
cartpole_offppo_create_config
12+
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
13+
cartpole_ppo_offpolicy_create_config
1414
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
1515
from ding.reward_model.trex_reward_model import TrexConvEncoder
1616
from ding.torch_utils import is_differentiable
@@ -19,16 +19,14 @@
1919
@pytest.mark.unittest
2020
def test_serial_pipeline_trex():
2121
exp_name = 'test_serial_pipeline_trex_expert'
22-
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
22+
config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
2323
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
2424
config[0].exp_name = exp_name
2525
expert_policy = serial_pipeline(config, seed=0)
2626

2727
exp_name = 'test_serial_pipeline_trex_collect'
2828
config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
2929
config[0].exp_name = exp_name
30-
config[0].reward_model.data_path = exp_name
31-
config[0].reward_model.reward_model_path = exp_name + '/cartpole.params'
3230
config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_expert'
3331
config[0].reward_model.checkpoint_max = 100
3432
config[0].reward_model.checkpoint_step = 100

ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ def test_serial_pipeline_trex_onpolicy():
2424
exp_name = 'test_serial_pipeline_trex_onpolicy_collect'
2525
config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
2626
config[0].exp_name = exp_name
27-
config[0].reward_model.data_path = exp_name
28-
config[0].reward_model.reward_model_path = exp_name + '/cartpole.params'
2927
config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_onpolicy_expert'
3028
config[0].reward_model.checkpoint_max = 100
3129
config[0].reward_model.checkpoint_step = 100

ding/entry/tests/test_serial_entry_reward_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from copy import deepcopy
66

77
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
8-
from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config, cartpole_offppo_create_config # noqa
8+
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
99
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
1010
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
1111
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
@@ -44,13 +44,13 @@
4444
@pytest.mark.parametrize('reward_model_config', cfg)
4545
def test_irl(reward_model_config):
4646
reward_model_config = EasyDict(reward_model_config)
47-
config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
47+
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
4848
expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
4949
# collect expert demo data
5050
collect_count = 10000
5151
expert_data_path = 'expert_data.pkl'
5252
state_dict = expert_policy.collect_mode.state_dict()
53-
config = deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)
53+
config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
5454
collect_demo_data(
5555
config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
5656
)

0 commit comments

Comments
 (0)