
Commit 0c48c08

refactor trex config file
1 parent f099cac commit 0c48c08

17 files changed: +87 −84 lines changed
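
Every config touched here follows the same pattern: the older TREX entry points (serial_pipeline_preference_based_irl, serial_pipeline_reward_model_trex, serial_pipeline_trex) are replaced by serial_pipeline_reward_model_offpolicy called with pretrain_reward=True, cooptrain_reward=False, placeholder expert_model_path values are filled with the expert run's exp_name, and reward_model=dict(type='trex') is added to each create config. Assembled from the hunks below, the launch block at the bottom of each file now reads roughly as follows (a sketch: main_config and create_config are the module-level configs defined earlier in each file, and the block sits under the original if __name__ == '__main__': guard):

import argparse
import torch
from ding.entry import trex_collecting_data
from ding.entry import serial_pipeline_reward_model_offpolicy

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()
# Collect episodic demonstration data from the expert checkpoints for the TREX reward model.
trex_collecting_data(args)
# Pretrain the reward model once (pretrain_reward=True) and keep it fixed during RL (cooptrain_reward=False).
serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)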

dizoo/atari/config/serial/pong/pong_trex_offppo_config.py

Lines changed: 5 additions & 3 deletions
@@ -13,6 +13,7 @@
     ),
     reward_model=dict(
         type='trex',
+        exp_name='pong_trex_offppo_seed0',
         min_snippet_length=50,
         max_snippet_length=100,
         checkpoint_min=0,
@@ -24,7 +25,7 @@
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
         # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
-        expert_model_path='model_path_placeholder',
+        expert_model_path='pong_ppo_seed0',
         hidden_size_list=[512, 64, 1],
         obs_shape=[4, 84, 84],
         action_shape=6,
@@ -76,6 +77,7 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='ppo_offpolicy'),
+    reward_model=dict(type='trex'),
 )
 pong_trex_ppo_create_config = EasyDict(pong_trex_ppo_create_config)
 create_config = pong_trex_ppo_create_config
@@ -87,12 +89,12 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_preference_based_irl
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_preference_based_irl((main_config, create_config))
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)

dizoo/atari/config/serial/pong/pong_trex_sql_config.py

Lines changed: 4 additions & 3 deletions
@@ -25,7 +25,7 @@
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
         # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
-        expert_model_path='model_path_placeholder',
+        expert_model_path='pong_sql_seed0',
         hidden_size_list=[512, 64, 1],
         obs_shape=[4, 84, 84],
         action_shape=6,
@@ -62,6 +62,7 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='sql'),
+    reward_model=dict(type='trex'),
 )
 pong_trex_sql_create_config = EasyDict(pong_trex_sql_create_config)
 create_config = pong_trex_sql_create_config
@@ -73,12 +74,12 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_preference_based_irl
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_preference_based_irl((main_config, create_config))
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)

dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py

Lines changed: 4 additions & 3 deletions
@@ -21,7 +21,7 @@
         checkpoint_step=100,
         learning_rate=1e-5,
         update_per_collect=1,
-        expert_model_path='abs model path',
+        expert_model_path='qbert_dqn_seed0',
         hidden_size_list=[512, 64, 1],
         obs_shape=[4, 84, 84],
         action_shape=6,
@@ -64,6 +64,7 @@
     ),
     env_manager=dict(type='base'),
     policy=dict(type='dqn'),
+    reward_model=dict(type='trex'),
 )
 qbert_trex_dqn_create_config = EasyDict(qbert_trex_dqn_create_config)
 create_config = qbert_trex_dqn_create_config
@@ -76,7 +77,7 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_reward_model_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
@@ -85,4 +86,4 @@
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_reward_model_trex((main_config, create_config))
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)

dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py

Lines changed: 4 additions & 3 deletions
@@ -21,7 +21,7 @@
         checkpoint_step=100,
         learning_rate=1e-5,
         update_per_collect=1,
-        expert_model_path='abs model path',
+        expert_model_path='qbert_ppo_seed0',
         hidden_size_list=[512, 64, 1],
         obs_shape=[4, 84, 84],
         action_shape=6,
@@ -71,6 +71,7 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='ppo_offpolicy'),
+    reward_model=dict(type='trex'),
 )
 create_config = EasyDict(qbert_trex_ppo_create_config)
 
@@ -82,7 +83,7 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_reward_model_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
@@ -91,4 +92,4 @@
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_reward_model_trex((main_config, create_config))
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)

dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py

Lines changed: 8 additions & 13 deletions
@@ -15,6 +15,7 @@
     ),
     reward_model=dict(
         type='trex',
+        exp_name='spaceinvaders_trex_dqn_seed0',
         min_snippet_length=50,
         max_snippet_length=100,
         checkpoint_min=10000,
@@ -28,17 +29,10 @@
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name``.
         # For example, if you want to use dqn to generate demos, you can use ``spaceinvaders_dqn``
-        expert_model_path='model_path_placeholder',
-        # path to save reward model
-        # Users should add their own model path here.
-        # Absolute path is recommended.
-        # For example, if you use ``spaceinvaders_drex``, then the reward model will be saved in this directory.
-        reward_model_path='model_path_placeholder + ./spaceinvaders.params',
-        # path to save generated observations.
-        # Users should add their own model path here.
-        # Absolute path is recommended.
-        # For example, if you use ``spaceinvaders_drex``, then all the generated data will be saved in this directory.
-        offline_data_path='data_path_placeholder',
+        expert_model_path='spaceinvaders_dqn_seed0',
+        hidden_size_list=[512, 64, 1],
+        obs_shape=[4, 84, 84],
+        action_shape=6,
     ),
     policy=dict(
         cuda=True,
@@ -78,6 +72,7 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='dqn'),
+    reward_model=dict(type='trex'),
 )
 spaceinvaders_trex_dqn_create_config = EasyDict(spaceinvaders_trex_dqn_create_config)
 create_config = spaceinvaders_trex_dqn_create_config
@@ -89,12 +84,12 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_trex([main_config, create_config])
+    serial_pipeline_reward_model_offpolicy([main_config, create_config], pretrain_reward=True, cooptrain_reward=False)
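
For reference, the reward_model block of spaceinvaders_trex_dqn_config.py after this change reads roughly as sketched below; keys that the diff does not show are omitted, and the variable name wrapping the dict is illustrative only. The D-REX style reward_model_path and offline_data_path entries are gone, replaced by the reward network shape keys (hidden_size_list, obs_shape, action_shape):

# Sketch of the refactored reward_model block (keys not shown in the diff are omitted).
spaceinvaders_trex_dqn_reward_model = dict(
    type='trex',
    exp_name='spaceinvaders_trex_dqn_seed0',
    min_snippet_length=50,
    max_snippet_length=100,
    checkpoint_min=10000,
    # ``expert_model_path`` is the ``exp_name`` of the expert run that generated the demos.
    expert_model_path='spaceinvaders_dqn_seed0',
    hidden_size_list=[512, 64, 1],
    obs_shape=[4, 84, 84],
    action_shape=6,
)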

dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py

Lines changed: 8 additions & 13 deletions
@@ -15,6 +15,7 @@
     ),
     reward_model=dict(
         type='trex',
+        exp_name='spaceinvaders_trex_offppo_seed0',
         min_snippet_length=30,
         max_snippet_length=100,
         checkpoint_min=0,
@@ -27,17 +28,10 @@
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name``.
         # For example, if you want to use dqn to generate demos, you can use ``spaceinvaders_dqn``
-        expert_model_path='model_path_placeholder',
-        # path to save reward model
-        # Users should add their own model path here.
-        # Absolute path is recommended.
-        # For example, if you use ``spaceinvaders_drex``, then the reward model will be saved in this directory.
-        reward_model_path='model_path_placeholder + ./spaceinvaders.params',
-        # path to save generated observations.
-        # Users should add their own model path here.
-        # Absolute path is recommended.
-        # For example, if you use ``spaceinvaders_drex``, then all the generated data will be saved in this directory.
-        offline_data_path='data_path_placeholder',
+        expert_model_path='spaceinvaders_ppo_seed0',
+        hidden_size_list=[512, 64, 1],
+        obs_shape=[4, 84, 84],
+        action_shape=6,
     ),
     policy=dict(
         cuda=True,
@@ -85,6 +79,7 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='ppo_offpolicy'),
+    reward_model=dict(type='trex'),
 )
 spaceinvaders_trex_ppo_create_config = EasyDict(spaceinvaders_trex_ppo_create_config)
 create_config = spaceinvaders_trex_ppo_create_config
@@ -96,12 +91,12 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_trex([main_config, create_config])
+    serial_pipeline_reward_model_offpolicy([main_config, create_config], pretrain_reward=True, cooptrain_reward=False)

dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py

Lines changed: 4 additions & 3 deletions
@@ -26,7 +26,7 @@
         # Users should add their own model path here. Model path should lead to a model.
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
-        expert_model_path='model_path_placeholder',
+        expert_model_path='lunarlander_dqn_seed0',
         hidden_size_list=[512, 64, 1],
         obs_shape=8,
         action_shape=4,
@@ -85,6 +85,7 @@
     ),
     env_manager=dict(type='subprocess'),
     policy=dict(type='dqn'),
+    reward_model=dict(type='trex'),
 )
 lunarlander_trex_dqn_create_config = EasyDict(lunarlander_trex_dqn_create_config)
 create_config = lunarlander_trex_dqn_create_config
@@ -96,12 +97,12 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_trex([main_config, create_config])
+    serial_pipeline_reward_model_offpolicy([main_config, create_config], pretrain_reward=True, cooptrain_reward=False)

dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
         # Absolute path is recommended.
         # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
         # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
-        expert_model_path='model_path_placeholder',
+        expert_model_path='lunarlander_offppo_seed0',
         hidden_size_list=[512, 64, 1],
         obs_shape=8,
         action_shape=4,
@@ -73,12 +73,12 @@
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_trex([main_config, create_config])
+    serial_pipeline_reward_model_offpolicy([main_config, create_config], pretrain_reward=True, cooptrain_reward=False)

dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py

Lines changed: 6 additions & 2 deletions
@@ -10,6 +10,7 @@
     ),
     reward_model=dict(
         type='trex',
+        exp_name='cartpole_trex_dqn_seed0',
         min_snippet_length=5,
         max_snippet_length=100,
         checkpoint_min=0,
@@ -61,6 +62,7 @@
     ),
     env_manager=dict(type='base'),
     policy=dict(type='dqn'),
+    reward_model=dict(type='trex'),
 )
 cartpole_trex_dqn_create_config = EasyDict(cartpole_trex_dqn_create_config)
 create_config = cartpole_trex_dqn_create_config
@@ -69,15 +71,17 @@
     # Users should first run ``cartpole_dqn_config.py`` to save models (or checkpoints).
     # Note: Users should check that the checkpoints generated should include iteration_'checkpoint_min'.pth.tar, iteration_'checkpoint_max'.pth.tar with the interval checkpoint_step
     # where checkpoint_max, checkpoint_min, checkpoint_step are specified above.
+    # example of running this file:
+    # python cartpole_trex_dqn_config.py --cfg cartpole_trex_dqn_config.py --seed 0 --device cpu
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_reward_model_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_reward_model_trex((main_config, create_config))
+    serial_pipeline_reward_model_offpolicy((main_config, create_config),pretrain_reward=True, cooptrain_reward=False)

dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py

Lines changed: 5 additions & 3 deletions
@@ -18,7 +18,7 @@
         checkpoint_step=100,
         learning_rate=1e-5,
         update_per_collect=1,
-        expert_model_path='abs model path',
+        expert_model_path='cartpole_ppo_seed0', # expert model experiment directory path
         hidden_size_list=[512, 64, 1],
         obs_shape=4,
         action_shape=2,
@@ -68,15 +68,17 @@
     # Users should first run ``cartpole_offppo_config.py`` to save models (or checkpoints).
     # Note: Users should check that the checkpoints generated should include iteration_'checkpoint_min'.pth.tar, iteration_'checkpoint_max'.pth.tar with the interval checkpoint_step
     # where checkpoint_max, checkpoint_min, checkpoint_step are specified above.
+    # example:
+    # python cartpole_trex_offppo_config.py --cfg cartpole_trex_offppo_config.py --seed 0 --device cpu
     import argparse
     import torch
     from ding.entry import trex_collecting_data
-    from ding.entry import serial_pipeline_reward_model_trex
+    from ding.entry import serial_pipeline_reward_model_offpolicy
     parser = argparse.ArgumentParser()
     parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
     args = parser.parse_args()
     # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
     trex_collecting_data(args)
-    serial_pipeline_reward_model_trex((main_config, create_config))
+    serial_pipeline_reward_model_offpolicy((main_config, create_config), pretrain_reward=True, cooptrain_reward=False)
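
The comments in the two cartpole configs above require that the expert run already contains checkpoints iteration_'checkpoint_min'.pth.tar through iteration_'checkpoint_max'.pth.tar spaced by checkpoint_step. A small hypothetical helper (not part of this commit; it assumes the checkpoints sit under the exp_name/ckpt/ directory mentioned in the config comments) can verify that before launching:

import os

def missing_trex_checkpoints(expert_exp_name, checkpoint_min, checkpoint_max, checkpoint_step):
    # Return the iteration checkpoints TREX expects but cannot find under <expert_exp_name>/ckpt/.
    ckpt_dir = os.path.join(expert_exp_name, 'ckpt')
    required = [
        'iteration_{}.pth.tar'.format(k)
        for k in range(checkpoint_min, checkpoint_max + 1, checkpoint_step)
    ]
    return [name for name in required if not os.path.exists(os.path.join(ckpt_dir, name))]

# Usage (values are illustrative, not taken from the configs above):
# print(missing_trex_checkpoints('cartpole_dqn_seed0', 0, 1000, 100))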
