@@ -11,37 +11,43 @@
     ),
     reward_model=dict(
         type='drex',
+        exp_name='cartpole_drex_dqn_seed0',
         min_snippet_length=5,
         max_snippet_length=100,
         checkpoint_min=0,
-        checkpoint_max=1000,
-        checkpoint_step=1000,
+        checkpoint_max=760,
+        checkpoint_step=760,
         learning_rate=1e-5,
         update_per_collect=1,
         # path to expert models that generate demonstration data
         # Users should add their own model path here. Model path should lead to an exp_name.
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name``.
         # For example, if you want to use dqn to generate demos, you can use ``spaceinvaders_dqn``
-        expert_model_path='expert_model_path_placeholder',
+        expert_model_path='cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar',
         # path to save reward model
         # Users should add their own model path here.
         # Absolute path is recommended.
         # For example, if you use ``spaceinvaders_drex``, then the reward model will be saved in this directory.
-        reward_model_path='reward_model_path_placeholder + ./spaceinvaders.params',
+        reward_model_path='cartpole_drex_dqn_seed0/cartpole.params',
         # path to save generated observations.
         # Users should add their own model path here.
         # Absolute path is recommended.
         # For example, if you use ``spaceinvaders_drex``, then all the generated data will be saved in this directory.
-        offline_data_path='offline_data_path_placeholder',
+        offline_data_path='cartpole_drex_dqn_seed0',
         # path to pretrained bc model. If omitted, bc will be trained instead.
         # Users should add their own model path here. Model path should lead to a model ckpt.
         # Absolute path is recommended.
-        bc_path='bc_path_placeholder',
+        # bc_path='bc_path_placeholder',
         # list of noises
         eps_list=[0, 0.5, 1],
         num_trajs_per_bin=20,
+        num_trajs=6,
+        num_snippets=6000,
         bc_iterations=6000,
+        hidden_size_list=[512, 64, 1],
+        obs_shape=4,
+        action_shape=2,
     ),
     policy=dict(
         cuda=False,
|
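For orientation, the reward_model block above parameterizes D-REX's automatic ranking: the BC policy (trained for bc_iterations) is rolled out at each noise level in eps_list, roughly num_trajs_per_bin trajectories are kept per level, and snippets of length between min_snippet_length and max_snippet_length are compared pairwise, ranked by the noise level they came from, to train the reward MLP described by hidden_size_list/obs_shape. The sketch below only illustrates that ranking objective under those assumptions; it is not DI-engine's implementation, and ``RewardNet`` and ``drex_ranking_loss`` are made-up names whose shapes just mirror obs_shape=4 and hidden_size_list=[512, 64, 1].

# Illustrative sketch only (not DI-engine code): the D-REX pairwise ranking objective.
import torch
import torch.nn as nn

class RewardNet(nn.Module):
    def __init__(self, obs_shape=4, hidden_size_list=(512, 64, 1)):
        super().__init__()
        layers, in_dim = [], obs_shape
        for h in hidden_size_list[:-1]:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, hidden_size_list[-1]))
        self.mlp = nn.Sequential(*layers)

    def forward(self, obs):
        # obs: (T, obs_shape) -> per-step predicted reward (T, 1)
        return self.mlp(obs)

def drex_ranking_loss(net, snippet_low_noise, snippet_high_noise):
    """Bradley-Terry style loss: the snippet generated with less noise is assumed
    to be better, so its predicted return should be larger."""
    ret_good = net(snippet_low_noise).sum()
    ret_bad = net(snippet_high_noise).sum()
    logits = torch.stack([ret_good, ret_bad]).unsqueeze(0)  # (1, 2)
    # cross-entropy with the low-noise snippet as the target class
    return nn.functional.cross_entropy(logits, torch.tensor([0]))

# toy usage: two random snippets whose length lies in [min_snippet_length, max_snippet_length]
net = RewardNet()
loss = drex_ranking_loss(net, torch.randn(60, 4), torch.randn(60, 4))
loss.backward()
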
@@ -57,7 +63,13 @@
             batch_size=64,
             learning_rate=0.001,
         ),
-        collect=dict(n_sample=8, collector=dict(get_train_sample=False, )),
+        collect=dict(
+            n_sample=8,
+            collector=dict(
+                get_train_sample=False,
+                reward_shaping=False,
+            ),
+        ),
         eval=dict(evaluator=dict(eval_freq=40, )),
         other=dict(
             eps=dict(
@@ -66,7 +78,7 @@
                 end=0.1,
                 decay=10000,
             ),
-            replay_buffer=dict(replay_buffer_size=20000, ),
+            replay_buffer=dict(replay_buffer_size=200000, ),
         ),
     ),
 )
@@ -79,7 +91,24 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='dqn'),
-    collector=dict(type='episode'),
 )
 cartpole_drex_dqn_create_config = EasyDict(cartpole_drex_dqn_create_config)
 create_config = cartpole_drex_dqn_create_config
+
+if __name__ == "__main__":
+    import argparse
+    import torch
+    from ding.config import read_config
+    from ding.entry import drex_collecting_data
+    from ding.entry import serial_pipeline_reward_model_offpolicy
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_args()
+    args.cfg = read_config(args.cfg)
+    args.cfg[1].policy.type = 'bc'
+    args.cfg[0].policy.collect.n_episode = 8
+    del args.cfg[0].policy.collect.n_sample
+    drex_collecting_data(args)
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)
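Usage note: the ``--cfg`` default above is a placeholder and should be replaced with the absolute path of this config file, e.g. ``python cartpole_drex_dqn_config.py --cfg /abs/path/to/cartpole_drex_dqn_config.py --seed 0`` (the path is illustrative). The entry assumes the expert run referenced by expert_model_path (a previous cartpole_dqn_seed0 experiment, with the checkpoints implied by checkpoint_min/checkpoint_max/checkpoint_step) already exists; drex_collecting_data then writes the noise-injected demonstrations under offline_data_path, and serial_pipeline_reward_model_offpolicy pre-trains the reward model (pretrain_reward=True, cooptrain_reward=False) before the off-policy DQN training that uses it.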