Skip to content

Commit 594d619

Browse files
committed
add gail to new reward entry
1 parent 0c48c08 commit 594d619

File tree

9 files changed

+90
-350
lines changed

9 files changed

+90
-350
lines changed

ding/entry/serial_entry_gail.py

Lines changed: 0 additions & 170 deletions
This file was deleted.

dizoo/atari/config/serial/pong/pong_gail_dqn_config.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
# Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
3232
# Absolute path is recommended.
3333
# In DI-engine, it is usually located in ``exp_name`` directory
34-
# e.g. 'exp_name/expert_data.pkl'
34+
# e.g. 'exp_name'
3535
data_path='data_path_placeholder',
3636
),
3737
policy=dict(
@@ -80,13 +80,17 @@
8080
# or you can enter `ding -m serial_gail -c pong_gail_dqn_config.py -s 0`
8181
# then input the config you used to generate your expert model in the path mentioned above
8282
# e.g. pong_dqn_config.py
83-
from ding.entry import serial_pipeline_gail
83+
from ding.entry import serial_pipeline_reward_model_offpolicy, collect_demo_data
8484
from dizoo.atari.config.serial.pong import pong_dqn_config, pong_dqn_create_config
85-
expert_main_config = pong_dqn_config
86-
expert_create_config = pong_dqn_create_config
87-
serial_pipeline_gail(
88-
(main_config, create_config), (expert_main_config, expert_create_config),
89-
max_env_step=1000000,
90-
seed=0,
91-
collect_data=True
85+
86+
# set your expert config here
87+
expert_cfg = (pong_dqn_config, pong_dqn_create_config)
88+
expert_data_path = main_config.reward_model.data_path + '/expert_data.pkl'
89+
90+
# collect expert data
91+
collect_demo_data(
92+
expert_cfg, seed=0, expert_data_path=expert_data_path, collect_count=main_config.reward_model.collect_count
9293
)
94+
95+
# train reward model
96+
serial_pipeline_reward_model_offpolicy(main_config, create_config)

dizoo/box2d/bipedalwalker/config/bipedalwalker_gail_sac_config.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,17 @@
8787
# or you can enter `ding -m serial_gail -c bipedalwalker_sac_gail_config.py -s 0`
8888
# then input the config you used to generate your expert model in the path mentioned above
8989
# e.g. bipedalwalker_sac_config.py
90-
from ding.entry import serial_pipeline_gail
90+
from ding.entry import serial_pipeline_reward_model_offpolicy, collect_demo_data
9191
from dizoo.box2d.bipedalwalker.config import bipedalwalker_sac_config, bipedalwalker_sac_create_config
92-
expert_main_config = bipedalwalker_sac_config
93-
expert_create_config = bipedalwalker_sac_create_config
94-
serial_pipeline_gail(
95-
[main_config, create_config], [expert_main_config, expert_create_config], seed=0, collect_data=True
92+
93+
# set your expert config here
94+
expert_cfg = (bipedalwalker_sac_config, bipedalwalker_sac_create_config)
95+
expert_data_path = main_config.reward_model.data_path + '/expert_data.pkl'
96+
97+
# collect expert data
98+
collect_demo_data(
99+
expert_cfg, seed=0, expert_data_path=expert_data_path, collect_count=main_config.reward_model.collect_count
96100
)
101+
102+
# train reward model
103+
serial_pipeline_reward_model_offpolicy(main_config, create_config)

dizoo/box2d/lunarlander/config/lunarlander_gail_dqn_config.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
3030
# Absolute path is recommended.
3131
# In DI-engine, it is usually located in ``exp_name`` directory
32-
# e.g. 'exp_name/expert_data.pkl'
32+
# e.g. 'exp_name'
3333
data_path='data_path_placeholder',
3434
),
3535
policy=dict(
@@ -96,13 +96,17 @@
9696
# or you can enter `ding -m serial_gail -c lunarlander_dqn_gail_config.py -s 0`
9797
# then input the config you used to generate your expert model in the path mentioned above
9898
# e.g. lunarlander_dqn_config.py
99-
from ding.entry import serial_pipeline_gail
99+
from ding.entry import serial_pipeline_reward_model_offpolicy, collect_demo_data
100100
from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config
101-
expert_main_config = lunarlander_dqn_config
102-
expert_create_config = lunarlander_dqn_create_config
103-
serial_pipeline_gail(
104-
[main_config, create_config], [expert_main_config, expert_create_config],
105-
max_env_step=1000000,
106-
seed=0,
107-
collect_data=True
101+
102+
# set your expert config here
103+
expert_cfg = (lunarlander_dqn_config, lunarlander_dqn_create_config)
104+
expert_data_path = main_config.reward_model.data_path + '/expert_data.pkl'
105+
106+
# collect expert data
107+
collect_demo_data(
108+
expert_cfg, seed=0, expert_data_path=expert_data_path, collect_count=main_config.reward_model.collect_count
108109
)
110+
111+
# train reward model
112+
serial_pipeline_reward_model_offpolicy(main_config, create_config)

dizoo/classic_control/cartpole/config/cartpole_dqn_gail_config.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
# In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
2121
# If collect_data is True, we will use this expert_model_path to collect expert data first, rather than we
2222
# will load data directly from user-defined data_path
23-
expert_model_path='model_path_placeholder',
23+
# data_path is the path to store expert policy data, which is used to train reward model
24+
# so in general, data_path is the same as expert exp name
25+
expert_model_path='cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar',
26+
data_path='cartpole_dqn_seed0',
2427
collect_count=1000,
2528
),
2629
policy=dict(
@@ -68,13 +71,22 @@
6871
# or you can enter `ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0`
6972
# then input the config you used to generate your expert model in the path mentioned above
7073
# e.g. cartpole_dqn_config.py
71-
from ding.entry import serial_pipeline_gail
74+
from ding.entry import serial_pipeline_reward_model_offpolicy, collect_demo_data
7275
from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config
76+
77+
# set expert config from policy config in dizoo
78+
expert_cfg = (cartpole_dqn_config, cartpole_dqn_create_config)
7379
expert_main_config = cartpole_dqn_config
74-
expert_create_config = cartpole_dqn_create_config
75-
serial_pipeline_gail(
76-
(main_config, create_config), (expert_main_config, expert_create_config),
77-
max_env_step=1000000,
80+
expert_data_path = main_config.reward_model.data_path + '/expert_data.pkl'
81+
82+
# collect expert data
83+
collect_demo_data(
84+
expert_cfg,
7885
seed=0,
79-
collect_data=True
86+
state_dict_path=main_config.reward_model.expert_model_path,
87+
expert_data_path=expert_data_path,
88+
collect_count=main_config.reward_model.collect_count
8089
)
90+
91+
# train reward model
92+
serial_pipeline_reward_model_offpolicy((main_config, create_config))

0 commit comments

Comments
 (0)