Skip to content

Commit a6baf30

Browse files
committed
add unittest for drex
1 parent ff4de47 commit a6baf30

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

ding/entry/tests/test_serial_entry_reward_model.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
1212
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
1313
from dizoo.classic_control.cartpole.config.cartpole_ngu_config import cartpole_ngu_config, cartpole_ngu_create_config
14+
from dizoo.classic_control.cartpole.config.cartpole_drex_dqn_config import cartpole_drex_dqn_config, cartpole_drex_dqn_create_config
1415
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
1516
serial_pipeline_reward_model_onpolicy
1617
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
18+
from ding.entry.application_entry_drex_collect_data import drex_collecting_data
1719

1820
cfg = [
1921
{
@@ -131,3 +133,38 @@ def test_trex():
131133
assert False, "pipeline fail"
132134
finally:
133135
os.popen('rm -rf test_serial_pipeline_trex*')
136+
137+
138+
@pytest.mark.unittest
139+
def test_drex():
140+
exp_name = 'test_serial_pipeline_drex_expert'
141+
config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
142+
config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
143+
config[0].exp_name = exp_name
144+
expert_policy = serial_pipeline(config, seed=0)
145+
146+
exp_name = 'test_serial_pipeline_drex_collect'
147+
config = [deepcopy(cartpole_drex_dqn_config), deepcopy(cartpole_drex_dqn_create_config)]
148+
config[0].exp_name = exp_name
149+
config[0].reward_model.exp_name = exp_name
150+
config[0].reward_model.expert_model_path = 'test_serial_pipeline_drex_expert/ckpt/ckpt_best.pth.tar'
151+
config[0].reward_model.reward_model_path = 'test_serial_pipeline_drex_collect/cartpole.params'
152+
config[0].reward_model.offline_data_path = 'test_serial_pipeline_drex_collect'
153+
config[0].reward_model.checkpoint_max = 100
154+
config[0].reward_model.checkpoint_step = 100
155+
config[0].reward_model.num_snippets = 100
156+
157+
args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
158+
args.cfg[0].policy.collect.n_episode = 8
159+
del args.cfg[0].policy.collect.n_sample
160+
args.cfg[0].bc_iteration = 1000 # for unittest
161+
args.cfg[1].policy.type = 'bc'
162+
drex_collecting_data(args=args)
163+
try:
164+
serial_pipeline_reward_model_offpolicy(
165+
config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False
166+
)
167+
except Exception:
168+
assert False, "pipeline fail"
169+
finally:
170+
os.popen('rm -rf test_serial_pipeline_drex*')

dizoo/classic_control/cartpole/config/cartpole_dqn_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
update_per_collect=5,
2525
batch_size=64,
2626
learning_rate=0.001,
27+
learner=dict(hook=dict(save_ckpt_after_iter=1000)),
2728
),
2829
collect=dict(n_sample=8),
2930
eval=dict(evaluator=dict(eval_freq=40, )),

0 commit comments

Comments
 (0)