|
11 | 11 | from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
|
12 | 12 | from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
|
13 | 13 | from dizoo.classic_control.cartpole.config.cartpole_ngu_config import cartpole_ngu_config, cartpole_ngu_create_config
|
| 14 | +from dizoo.classic_control.cartpole.config.cartpole_drex_dqn_config import cartpole_drex_dqn_config, cartpole_drex_dqn_create_config |
14 | 15 | from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
|
15 | 16 | serial_pipeline_reward_model_onpolicy
|
16 | 17 | from ding.entry.application_entry_trex_collect_data import trex_collecting_data
|
| 18 | +from ding.entry.application_entry_drex_collect_data import drex_collecting_data |
17 | 19 |
|
18 | 20 | cfg = [
|
19 | 21 | {
|
@@ -131,3 +133,38 @@ def test_trex():
|
131 | 133 | assert False, "pipeline fail"
|
132 | 134 | finally:
|
133 | 135 | os.popen('rm -rf test_serial_pipeline_trex*')
|
| 136 | + |
| 137 | + |
| 138 | +@pytest.mark.unittest |
| 139 | +def test_drex(): |
| 140 | + exp_name = 'test_serial_pipeline_drex_expert' |
| 141 | + config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)] |
| 142 | + config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100 |
| 143 | + config[0].exp_name = exp_name |
| 144 | + expert_policy = serial_pipeline(config, seed=0) |
| 145 | + |
| 146 | + exp_name = 'test_serial_pipeline_drex_collect' |
| 147 | + config = [deepcopy(cartpole_drex_dqn_config), deepcopy(cartpole_drex_dqn_create_config)] |
| 148 | + config[0].exp_name = exp_name |
| 149 | + config[0].reward_model.exp_name = exp_name |
| 150 | + config[0].reward_model.expert_model_path = 'test_serial_pipeline_drex_expert/ckpt/ckpt_best.pth.tar' |
| 151 | + config[0].reward_model.reward_model_path = 'test_serial_pipeline_drex_collect/cartpole.params' |
| 152 | + config[0].reward_model.offline_data_path = 'test_serial_pipeline_drex_collect' |
| 153 | + config[0].reward_model.checkpoint_max = 100 |
| 154 | + config[0].reward_model.checkpoint_step = 100 |
| 155 | + config[0].reward_model.num_snippets = 100 |
| 156 | + |
| 157 | + args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'}) |
| 158 | + args.cfg[0].policy.collect.n_episode = 8 |
| 159 | + del args.cfg[0].policy.collect.n_sample |
| 160 | + args.cfg[0].bc_iteration = 1000 # for unittest |
| 161 | + args.cfg[1].policy.type = 'bc' |
| 162 | + drex_collecting_data(args=args) |
| 163 | + try: |
| 164 | + serial_pipeline_reward_model_offpolicy( |
| 165 | + config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False |
| 166 | + ) |
| 167 | + except Exception: |
| 168 | + assert False, "pipeline fail" |
| 169 | + finally: |
| 170 | + os.popen('rm -rf test_serial_pipeline_drex*') |
0 commit comments