
Commit 0cc2149

add drex to new entry
1 parent 4c19aa3 commit 0cc2149

4 files changed (+59 additions, -14 deletions)

ding/entry/tests/test_application_entry_drex_collect_data.py

Lines changed: 2 additions & 2 deletions

@@ -70,5 +70,5 @@ def test_drex_collecting_data():
     args.cfg[0].bc_iteration = 1000  # for unittest
     args.cfg[1].policy.type = 'bc'
     drex_collecting_data(args=args)
-    os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
-    os.popen('rm -rf {}'.format(args.cfg[0].reward_model.offline_data_path))
+    # os.popen('rm -rf {}'.format(expert_policy_state_dict_path))
+    # os.popen('rm -rf {}'.format(args.cfg[0].reward_model.offline_data_path))
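A side note on the now-disabled cleanup: if it is reintroduced later, shutil.rmtree is a portable, blocking alternative to shelling out via os.popen('rm -rf ...'). A minimal sketch (not part of this commit), reusing the same path variables from the test:

    import shutil

    # Equivalent cleanup without spawning a shell; ignore_errors=True avoids a
    # failure when drex_collecting_data did not create the directory at all.
    shutil.rmtree(expert_policy_state_dict_path, ignore_errors=True)
    shutil.rmtree(args.cfg[0].reward_model.offline_data_path, ignore_errors=True)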

ding/reward_model/drex_reward_model.py

Lines changed: 18 additions & 2 deletions

@@ -1,6 +1,7 @@
 import copy
 from easydict import EasyDict
 import pickle
+import numpy as np

 from ding.utils import REWARD_MODEL_REGISTRY

@@ -77,11 +78,26 @@ def load_expert_data(self) -> None:
         """
         super(DrexRewardModel, self).load_expert_data()

-        with open(self.cfg.reward_model.offline_data_path + '/suboptimal_data.pkl', 'rb') as f:
+        with open(self.cfg.offline_data_path + '/suboptimal_data.pkl', 'rb') as f:
             self.demo_data = pickle.load(f)

     def train(self):
-        self._train()
+        # check if gpu available
+        device = self.device  # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # Assume that we are on a CUDA machine, then this should print a CUDA device:
+        self._logger.info("device: {}".format(device))
+        training_inputs, training_outputs = self.training_obs, self.training_labels
+
+        cum_loss = 0.0
+        training_data = list(zip(training_inputs, training_outputs))
+        for epoch in range(self.cfg.update_per_collect):
+            np.random.shuffle(training_data)
+            training_obs, training_labels = zip(*training_data)
+            cum_loss = self._train(training_obs, training_labels)
+            self.train_iter += 1
+            self._logger.info("[epoch {}] loss {}".format(epoch, cum_loss))
+            self.tb_logger.add_scalar("drex_reward/train_loss_iteration", cum_loss, self.train_iter)
+
         return_dict = self.pred_data(self.demo_data)
         res, pred_returns = return_dict['real'], return_dict['pred']
         self._logger.info("real: " + str(res))

ding/reward_model/ngu_reward_model.py

Lines changed: 1 addition & 1 deletion

@@ -453,7 +453,7 @@ class NGURewardModel(BaseRewardModel):
     )

     def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:
-        super(NGURewardModel).__init__()
+        super(NGURewardModel, self).__init__()
         self.cfg = config
         self.tb_logger = tb_logger
         self.estimate_cnt = 0
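This one-line fix matters because super(NGURewardModel).__init__() creates an unbound super object, so the parent constructor is never actually invoked; the two-argument form binds self and runs it. A minimal sketch of the difference (class names here are illustrative, not from the repo):

    class Base:
        def __init__(self):
            self.initialized = True

    class Broken(Base):
        def __init__(self):
            super(Broken).__init__()        # unbound super: Base.__init__ never runs

    class Fixed(Base):
        def __init__(self):
            super(Fixed, self).__init__()   # bound super: Base.__init__ runs

    print(hasattr(Broken(), 'initialized'))  # False
    print(hasattr(Fixed(), 'initialized'))   # True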

dizoo/classic_control/cartpole/config/cartpole_drex_dqn_config.py

Lines changed: 38 additions & 9 deletions

@@ -11,37 +11,43 @@
     ),
     reward_model=dict(
         type='drex',
+        exp_name='cartpole_drex_dqn_seed0',
         min_snippet_length=5,
         max_snippet_length=100,
         checkpoint_min=0,
-        checkpoint_max=1000,
-        checkpoint_step=1000,
+        checkpoint_max=760,
+        checkpoint_step=760,
         learning_rate=1e-5,
         update_per_collect=1,
         # path to expert models that generate demonstration data
         # Users should add their own model path here. Model path should lead to an exp_name.
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name``.
         # For example, if you want to use dqn to generate demos, you can use ``spaceinvaders_dqn``
-        expert_model_path='expert_model_path_placeholder',
+        expert_model_path='cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar',
         # path to save reward model
         # Users should add their own model path here.
         # Absolute path is recommended.
         # For example, if you use ``spaceinvaders_drex``, then the reward model will be saved in this directory.
-        reward_model_path='reward_model_path_placeholder + ./spaceinvaders.params',
+        reward_model_path='cartpole_drex_dqn_seed0/cartpole.params',
         # path to save generated observations.
         # Users should add their own model path here.
         # Absolute path is recommended.
         # For example, if you use ``spaceinvaders_drex``, then all the generated data will be saved in this directory.
-        offline_data_path='offline_data_path_placeholder',
+        offline_data_path='cartpole_drex_dqn_seed0',
         # path to pretrained bc model. If omitted, bc will be trained instead.
         # Users should add their own model path here. Model path should lead to a model ckpt.
         # Absolute path is recommended.
-        bc_path='bc_path_placeholder',
+        # bc_path='bc_path_placeholder',
         # list of noises
         eps_list=[0, 0.5, 1],
         num_trajs_per_bin=20,
+        num_trajs=6,
+        num_snippets=6000,
         bc_iterations=6000,
+        hidden_size_list=[512, 64, 1],
+        obs_shape=4,
+        action_shape=2,
     ),
     policy=dict(
         cuda=False,

@@ -57,7 +63,13 @@
             batch_size=64,
             learning_rate=0.001,
         ),
-        collect=dict(n_sample=8, collector=dict(get_train_sample=False, )),
+        collect=dict(
+            n_sample=8,
+            collector=dict(
+                get_train_sample=False,
+                reward_shaping=False,
+            ),
+        ),
         eval=dict(evaluator=dict(eval_freq=40, )),
         other=dict(
             eps=dict(

@@ -66,7 +78,7 @@
                 end=0.1,
                 decay=10000,
             ),
-            replay_buffer=dict(replay_buffer_size=20000, ),
+            replay_buffer=dict(replay_buffer_size=200000, ),
         ),
     ),
 )

@@ -79,7 +91,24 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='dqn'),
-    collector=dict(type='episode'),
 )
 cartpole_drex_dqn_create_config = EasyDict(cartpole_drex_dqn_create_config)
 create_config = cartpole_drex_dqn_create_config
+
+if __name__ == "__main__":
+    import argparse
+    import torch
+    from ding.config import read_config
+    from ding.entry import drex_collecting_data
+    from ding.entry import serial_pipeline_reward_model_offpolicy
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_args()
+    args.cfg = read_config(args.cfg)
+    args.cfg[1].policy.type = 'bc'
+    args.cfg[0].policy.collect.n_episode = 8
+    del args.cfg[0].policy.collect.n_sample
+    drex_collecting_data(args)
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)
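The new __main__ block turns this config file into a runnable DREX entry point: read_config loads the file and returns a pair of configs that the script indexes as args.cfg[0] and args.cfg[1], drex_collecting_data collects demonstration data at the noise levels in eps_list, and serial_pipeline_reward_model_offpolicy then pretrains the reward model before RL training (pretrain_reward=True, cooptrain_reward=False). A hedged usage sketch (the path is a placeholder you must replace with the absolute path to this file; --cfg, --seed and --device are the only flags the argparse block defines):

    python dizoo/classic_control/cartpole/config/cartpole_drex_dqn_config.py \
        --cfg /abs/path/to/cartpole_drex_dqn_config.py --seed 0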
