Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# True on every platform except Windows; used to gate deps with no Windows wheels.
IS_NOT_WINDOWS = os.name != "nt"

# Extras for running experiments in parallel via Ray Tune.
PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"]
# Extras for image-observation environments (Atari via ALE, and procgen).
# Renamed from ATARI_REQUIRE since it now also covers procgen.
IMAGE_ENV_REQUIRE = [
    "opencv-python",
    "ale-py==0.7.4",
    "pillow",
    "autorom[accept-rom-license]~=0.4.2",
    "procgen==0.10.7",
    # Pinned to a specific gym3 commit (procgen dependency not yet on PyPI at this rev).
    "gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226",  # noqa: E501
]
# pytype has no Windows support, so skip it there.
PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
STABLE_BASELINES3 = "stable-baselines3>=1.6.1"
Expand Down Expand Up @@ -61,7 +63,7 @@
"pre-commit>=2.20.0",
]
+ PARALLEL_REQUIRE
+ ATARI_REQUIRE
+ IMAGE_ENV_REQUIRE
+ PYTYPE
)
DOCS_REQUIRE = [
Expand All @@ -74,7 +76,7 @@
"sphinx-github-changelog~=1.2.0",
"myst-nb==0.16.0",
"ipykernel~=6.15.2",
] + ATARI_REQUIRE
] + IMAGE_ENV_REQUIRE


def get_readme() -> str:
Expand Down Expand Up @@ -231,7 +233,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
"mujoco": [
"gym[classic_control,mujoco]" + GYM_VERSION_SPECIFIER,
],
"atari": ATARI_REQUIRE,
"image_envs": IMAGE_ENV_REQUIRE,
},
entry_points={
"console_scripts": [
Expand Down
8 changes: 8 additions & 0 deletions src/imitation/scripts/common/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def sac():
locals() # quieten flake8


@rl_ingredient.named_config
def procgen_default():
    """Default PPO hyperparameters for procgen environments."""
    # Hyperparameters copied from "Goal Misgeneralization in Deep
    # Reinforcement Learning" (Langosco et al., 2022).
    rl_cls = sb3.PPO
    rl_kwargs = dict(
        gamma=0.999,
        learning_rate=5e-4,
        batch_size=2048,
        ent_coef=0.01,
        n_steps=1024,
    )
    locals()  # quieten flake8


def _maybe_add_relabel_buffer(
rl_kwargs: Dict[str, Any],
relabel_reward_fn: Optional[RewardFn] = None,
Expand Down
83 changes: 83 additions & 0 deletions src/imitation/scripts/config/eval_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,89 @@ def seals_walker():
common = dict(env_name="seals/Walker2d-v0")


# Procgen configs
# One named config per procgen game, selecting the corresponding gym env id.


@eval_policy_ex.named_config
def coinrun():
    common = {"env_name": "procgen:procgen-coinrun-v0"}


@eval_policy_ex.named_config
def maze():
    common = {"env_name": "procgen:procgen-maze-v0"}


@eval_policy_ex.named_config
def bigfish():
    common = {"env_name": "procgen:procgen-bigfish-v0"}


@eval_policy_ex.named_config
def bossfight():
    common = {"env_name": "procgen:procgen-bossfight-v0"}


@eval_policy_ex.named_config
def caveflyer():
    common = {"env_name": "procgen:procgen-caveflyer-v0"}


@eval_policy_ex.named_config
def chaser():
    common = {"env_name": "procgen:procgen-chaser-v0"}


@eval_policy_ex.named_config
def climber():
    common = {"env_name": "procgen:procgen-climber-v0"}


@eval_policy_ex.named_config
def dodgeball():
    common = {"env_name": "procgen:procgen-dodgeball-v0"}


@eval_policy_ex.named_config
def fruitbot():
    common = {"env_name": "procgen:procgen-fruitbot-v0"}


@eval_policy_ex.named_config
def heist():
    common = {"env_name": "procgen:procgen-heist-v0"}


@eval_policy_ex.named_config
def jumper():
    common = {"env_name": "procgen:procgen-jumper-v0"}


@eval_policy_ex.named_config
def leaper():
    common = {"env_name": "procgen:procgen-leaper-v0"}


@eval_policy_ex.named_config
def miner():
    common = {"env_name": "procgen:procgen-miner-v0"}


@eval_policy_ex.named_config
def ninja():
    common = {"env_name": "procgen:procgen-ninja-v0"}


@eval_policy_ex.named_config
def plunder():
    common = {"env_name": "procgen:procgen-plunder-v0"}


@eval_policy_ex.named_config
def starpilot():
    common = {"env_name": "procgen:procgen-starpilot-v0"}


@eval_policy_ex.named_config
def fast():
    # Minimal single-env setup for quick smoke tests.
    common = {"env_name": "seals/CartPole-v0", "num_vec": 1, "parallel": False}
Expand Down
99 changes: 99 additions & 0 deletions src/imitation/scripts/config/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,105 @@ def seals_walker():
common = dict(env_name="seals/Walker2d-v0")


# Procgen configs
# One named config per procgen game; each trains for 2e8 timesteps.


@train_rl_ex.named_config
def coinrun():
    common = {"env_name": "procgen:procgen-coinrun-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def maze():
    common = {"env_name": "procgen:procgen-maze-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def bigfish():
    common = {"env_name": "procgen:procgen-bigfish-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def bossfight():
    common = {"env_name": "procgen:procgen-bossfight-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def caveflyer():
    common = {"env_name": "procgen:procgen-caveflyer-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def chaser():
    common = {"env_name": "procgen:procgen-chaser-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def climber():
    common = {"env_name": "procgen:procgen-climber-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def dodgeball():
    common = {"env_name": "procgen:procgen-dodgeball-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def fruitbot():
    common = {"env_name": "procgen:procgen-fruitbot-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def heist():
    common = {"env_name": "procgen:procgen-heist-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def jumper():
    common = {"env_name": "procgen:procgen-jumper-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def leaper():
    common = {"env_name": "procgen:procgen-leaper-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def miner():
    common = {"env_name": "procgen:procgen-miner-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def ninja():
    common = {"env_name": "procgen:procgen-ninja-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def plunder():
    common = {"env_name": "procgen:procgen-plunder-v0"}
    total_timesteps = 200_000_000


@train_rl_ex.named_config
def starpilot():
    common = {"env_name": "procgen:procgen-starpilot-v0"}
    total_timesteps = 200_000_000


# Debug configs
Expand Down
23 changes: 23 additions & 0 deletions tests/scripts/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,29 @@ def test_train_rl_cnn_policy(tmpdir: str, rng):
assert isinstance(run.result, dict)


def test_train_rl_coinrun(tmpdir: str, rng):
    """Smoke-test train_rl on procgen coinrun with a CNN policy and reward net."""
    environment = util.make_vec_env(
        "procgen:procgen-coinrun-v0",
        n_envs=1,
        parallel=False,
        rng=rng,
    )
    # Serialize a freshly initialized CNN reward net for train_rl to load.
    reward_net = reward_nets.CnnRewardNet(
        environment.observation_space,
        environment.action_space,
    )
    reward_path = os.path.join(tmpdir, "reward.pt")
    th.save(reward_net, reward_path)

    rl_log_dir = os.path.join(tmpdir, "train_rl")
    named = ["train.cnn_policy", *ALGO_FAST_CONFIGS["rl"], "coinrun"]
    run = train_rl.train_rl_ex.run(
        named_configs=named,
        config_updates=dict(
            common=dict(log_dir=rl_log_dir),
            reward_path=reward_path,
        ),
    )
    assert run.status == "COMPLETED"
    assert isinstance(run.result, dict)


PARALLEL_CONFIG_UPDATES = [
dict(
sacred_ex_name="train_rl",
Expand Down