Commit 5b4e4cc

fix unit test for trex and gail
1 parent 0cc2149 commit 5b4e4cc

4 files changed: +36 -32 lines changed

ding/entry/tests/test_serial_entry_reward_model.py

Lines changed: 31 additions & 26 deletions
@@ -38,20 +38,6 @@
         'hidden_size_list': [64, 1],
         'update_per_collect': 200,
         'batch_size': 128,
-    }, {
-        'type': 'trex',
-        'exp_name': 'cartpole_trex_offppo_seed0',
-        'min_snippet_length': 5,
-        'max_snippet_length': 100,
-        'checkpoint_min': 0,
-        'checkpoint_max': 6,
-        'checkpoint_step': 6,
-        'learning_rate': 1e-5,
-        'update_per_collect': 1,
-        'expert_model_path': 'cartpole_ppo_offpolicy_seed0',
-        'hidden_size_list': [512, 64, 1],
-        'obs_shape': 4,
-        'action_shape': 2,
     }
 ]

@@ -67,15 +53,9 @@ def test_irl(reward_model_config):
     expert_data_path = 'expert_data.pkl'
     state_dict = expert_policy.collect_mode.state_dict()
     config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
-    if reward_model_config.type == 'trex':
-        trex_config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
-        trex_config[0].reward_model = reward_model_config
-        args = EasyDict({'cfg': deepcopy(trex_config), 'seed': 0, 'device': 'cpu'})
-        trex_collecting_data(args=args)
-    else:
-        collect_demo_data(
-            config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
-        )
+    collect_demo_data(
+        config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
+    )
     # irl + rl training
     cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
     cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
@@ -88,9 +68,6 @@ def test_irl(reward_model_config):
     cp_cartpole_dqn_config.policy.collect.n_sample = 128
     cooptrain_reward = True
     pretrain_reward = False
-    if reward_model_config.type == 'trex':
-        cooptrain_reward = False
-        pretrain_reward = True
     serial_pipeline_reward_model_offpolicy(
         (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config),
         seed=0,
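
The removed branch above is where TREX used to flip these two flags inside test_irl; after this commit the standalone test added below passes them to the pipeline directly. As a rough illustration of what the flags control (a hypothetical sketch, not the actual serial_pipeline_reward_model_offpolicy implementation; all names other than the two flags are placeholders):

def pipeline_sketch(reward_model, collector, learner, max_train_iter, pretrain_reward, cooptrain_reward):
    # Hypothetical gating of reward-model training by the two flags used in these tests.
    if pretrain_reward:
        # One-off reward learning from pre-collected demonstrations (TREX-style).
        reward_model.train()
    for _ in range(max_train_iter):
        new_data = collector.collect()
        if cooptrain_reward:
            # Keep updating the reward model alongside the policy (GAIL/RND-style).
            reward_model.collect_data(new_data)
            reward_model.train()
        new_data = reward_model.estimate(new_data)  # relabel rewards before the policy update
        learner.train(new_data)
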
@@ -126,3 +103,31 @@ def test_ngu():
         serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
     except Exception:
         assert False, "pipeline fail"
+
+
+@pytest.mark.unittest
+def test_trex():
+    exp_name = 'test_serial_pipeline_trex_expert'
+    config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+    config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
+    config[0].exp_name = exp_name
+    expert_policy = serial_pipeline(config, seed=0)
+
+    exp_name = 'test_serial_pipeline_trex_collect'
+    config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
+    config[0].exp_name = exp_name
+    config[0].reward_model.exp_name = exp_name
+    config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_expert'
+    config[0].reward_model.checkpoint_max = 100
+    config[0].reward_model.checkpoint_step = 100
+    config[0].reward_model.num_snippets = 100
+    args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
+    trex_collecting_data(args=args)
+    try:
+        serial_pipeline_reward_model_offpolicy(
+            config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False
+        )
+    except Exception:
+        assert False, "pipeline fail"
+    finally:
+        os.popen('rm -rf test_serial_pipeline_trex*')
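
For local verification, the new standalone test can be run on its own; a minimal invocation sketch (assuming pytest is installed and the repository root is the working directory, which is not part of the commit itself):

import subprocess

# Run only the new TREX test added in this commit.
subprocess.run(
    ['pytest', 'ding/entry/tests/test_serial_entry_reward_model.py::test_trex', '-s'],
    check=True,  # raise CalledProcessError if the test fails
)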

ding/reward_model/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 # exploration
 from .rnd_reward_model import RndRewardModel
 from .guided_cost_reward_model import GuidedCostRewardModel
-from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
+from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel, NGURewardModel
 from .icm_reward_model import ICMRewardModel
 from .network import RepresentationNetwork, RNDNetwork, REDNetwork, GAILNetwork, ICMNetwork, GCLNetwork, TREXNetwork
 from .reword_model_utils import concat_state_action_pairs, combine_intrinsic_exterinsic_reward, obs_norm, collect_states
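
With this re-export, NGURewardModel should be importable from the package root alongside the other NGU models; a minimal usage sketch:

# NGURewardModel is now exposed by ding/reward_model/__init__.py (per the line changed above).
from ding.reward_model import RndNGURewardModel, EpisodicNGURewardModel, NGURewardModel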

ding/reward_model/red_irl_model.py

Lines changed: 0 additions & 3 deletions
@@ -1,11 +1,8 @@
 from typing import Dict, List
 import pickle
 import random
-from collections.abc import Iterable
 
-import torch
 import torch.optim as optim
-import torch.nn.functional as F
 
 from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
 from .base_reward_model import BaseRewardModel

ding/reward_model/tests/test_gail_irl_model.py

Lines changed: 4 additions & 2 deletions
@@ -28,6 +28,7 @@
     learning_rate=1e-3,
     update_per_collect=2,
     data_path=expert_data_path_1d,
+    clear_buffer_per_iters=1,
 ),
 
 cfg2 = dict(
@@ -40,6 +41,7 @@
     update_per_collect=2,
     data_path=expert_data_path_3d,
     action_size=action_space,
+    clear_buffer_per_iters=1,
 ),
 
 # create fake expert dataset
@@ -77,7 +79,7 @@ def test_dataset_1d(cfg):
     policy.train()
     train_data_augmented = policy.estimate(data)
     assert 'reward' in train_data_augmented[0].keys()
-    policy.clear_data()
+    policy.clear_data(iter=1)
     assert len(policy.train_data) == 0
     os.popen('rm -rf {}'.format(expert_data_path_1d))
 
@@ -101,6 +103,6 @@ def test_dataset_3d(cfg):
     policy.train()
     train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
-    policy.clear_data()
+    policy.clear_data(iter=1)
     assert len(policy.train_data) == 0
     os.popen('rm -rf {}'.format(expert_data_path_3d))
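
The two changes in this file go together: each config gains clear_buffer_per_iters=1 and both tests now call clear_data(iter=1). A minimal sketch of the behaviour they appear to exercise (a hypothetical illustration, not the actual GailRewardModel code):

class GailBufferSketch:
    # Toy stand-in: the collected data buffer is only emptied when the current
    # iteration is a multiple of clear_buffer_per_iters.
    def __init__(self, clear_buffer_per_iters: int) -> None:
        self.clear_buffer_per_iters = clear_buffer_per_iters
        self.train_data: list = []

    def clear_data(self, iter: int) -> None:
        if iter % self.clear_buffer_per_iters == 0:
            self.train_data.clear()


# With clear_buffer_per_iters=1 (as set in cfg1/cfg2 above), every call clears the
# buffer, so the assertion len(policy.train_data) == 0 in both tests still holds.
buf = GailBufferSketch(clear_buffer_per_iters=1)
buf.train_data.append({'obs': None, 'reward': 0.0})
buf.clear_data(iter=1)
assert len(buf.train_data) == 0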
