From 6c91610da27ea0b016477313d621facd54da5213 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 29 Nov 2022 15:31:28 -0800 Subject: [PATCH 1/5] Add support for procgen training --- setup.py | 10 ++- src/imitation/scripts/config/eval_policy.py | 83 +++++++++++++++++++++ src/imitation/scripts/config/train_rl.py | 83 +++++++++++++++++++++ 3 files changed, 172 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 151f17ea9..5d9270996 100644 --- a/setup.py +++ b/setup.py @@ -14,11 +14,13 @@ IS_NOT_WINDOWS = os.name != "nt" PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"] -ATARI_REQUIRE = [ +IMAGE_ENV_REQUIRE = [ "opencv-python", "ale-py==0.7.4", "pillow", "autorom[accept-rom-license]~=0.4.2", + "procgen==0.10.7", + "gym3@git+https://github.com/openai/gym3.git@4c3824680eaf9dd04dce224ee3d4856429878226", # noqa: E501 ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] STABLE_BASELINES3 = "stable-baselines3>=1.6.1" @@ -61,7 +63,7 @@ "pre-commit>=2.20.0", ] + PARALLEL_REQUIRE - + ATARI_REQUIRE + + IMAGE_ENV_REQUIRE + PYTYPE ) DOCS_REQUIRE = [ @@ -74,7 +76,7 @@ "sphinx-github-changelog~=1.2.0", "myst-nb==0.16.0", "ipykernel~=6.15.2", -] + ATARI_REQUIRE +] + IMAGE_ENV_REQUIRE def get_readme() -> str: @@ -231,7 +233,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "mujoco": [ "gym[classic_control,mujoco]" + GYM_VERSION_SPECIFIER, ], - "atari": ATARI_REQUIRE, + "image_envs": IMAGE_ENV_REQUIRE, }, entry_points={ "console_scripts": [ diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index 9bc8e29a6..907e2fa63 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -116,6 +116,89 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Procgen configs + + +@eval_policy_ex.named_config +def coinrun(): + common = dict(env_name="procgen:procgen-coinrun-v0") + + +@eval_policy_ex.named_config +def maze(): + common
= dict(env_name="procgen:procgen-maze-v0") + + +@eval_policy_ex.named_config +def bigfish(): + common = dict(env_name="procgen:procgen-bigfish-v0") + + +@eval_policy_ex.named_config +def bossfight(): + common = dict(env_name="procgen:procgen-bossfight-v0") + + +@eval_policy_ex.named_config +def caveflyer(): + common = dict(env_name="procgen:procgen-caveflyer-v0") + + +@eval_policy_ex.named_config +def chaser(): + common = dict(env_name="procgen:procgen-chaser-v0") + + +@eval_policy_ex.named_config +def climber(): + common = dict(env_name="procgen:procgen-climber-v0") + + +@eval_policy_ex.named_config +def dodgeball(): + common = dict(env_name="procgen:procgen-dodgeball-v0") + + +@eval_policy_ex.named_config +def fruitbot(): + common = dict(env_name="procgen:procgen-fruitbot-v0") + + +@eval_policy_ex.named_config +def heist(): + common = dict(env_name="procgen:procgen-heist-v0") + + +@eval_policy_ex.named_config +def jumper(): + common = dict(env_name="procgen:procgen-jumper-v0") + + +@eval_policy_ex.named_config +def leaper(): + common = dict(env_name="procgen:procgen-leaper-v0") + + +@eval_policy_ex.named_config +def miner(): + common = dict(env_name="procgen:procgen-miner-v0") + + +@eval_policy_ex.named_config +def ninja(): + common = dict(env_name="procgen:procgen-ninja-v0") + + +@eval_policy_ex.named_config +def plunder(): + common = dict(env_name="procgen:procgen-plunder-v0") + + +@eval_policy_ex.named_config +def starpilot(): + common = dict(env_name="procgen:procgen-starpilot-v0") + + @eval_policy_ex.named_config def fast(): common = dict(env_name="seals/CartPole-v0", num_vec=1, parallel=False) diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index b9ede3165..e0a2fc8b9 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -132,6 +132,89 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Procgen configs + + +@train_rl_ex.named_config +def 
coinrun(): + common = dict(env_name="procgen:procgen-coinrun-v0") + + +@train_rl_ex.named_config +def maze(): + common = dict(env_name="procgen:procgen-maze-v0") + + +@train_rl_ex.named_config +def bigfish(): + common = dict(env_name="procgen:procgen-bigfish-v0") + + +@train_rl_ex.named_config +def bossfight(): + common = dict(env_name="procgen:procgen-bossfight-v0") + + +@train_rl_ex.named_config +def caveflyer(): + common = dict(env_name="procgen:procgen-caveflyer-v0") + + +@train_rl_ex.named_config +def chaser(): + common = dict(env_name="procgen:procgen-chaser-v0") + + +@train_rl_ex.named_config +def climber(): + common = dict(env_name="procgen:procgen-climber-v0") + + +@train_rl_ex.named_config +def dodgeball(): + common = dict(env_name="procgen:procgen-dodgeball-v0") + + +@train_rl_ex.named_config +def fruitbot(): + common = dict(env_name="procgen:procgen-fruitbot-v0") + + +@train_rl_ex.named_config +def heist(): + common = dict(env_name="procgen:procgen-heist-v0") + + +@train_rl_ex.named_config +def jumper(): + common = dict(env_name="procgen:procgen-jumper-v0") + + +@train_rl_ex.named_config +def leaper(): + common = dict(env_name="procgen:procgen-leaper-v0") + + +@train_rl_ex.named_config +def miner(): + common = dict(env_name="procgen:procgen-miner-v0") + + +@train_rl_ex.named_config +def ninja(): + common = dict(env_name="procgen:procgen-ninja-v0") + + +@train_rl_ex.named_config +def plunder(): + common = dict(env_name="procgen:procgen-plunder-v0") + + +@train_rl_ex.named_config +def starpilot(): + common = dict(env_name="procgen:procgen-starpilot-v0") + + # Debug configs From 5a56417ae1b97a7c5bf5057bc9d21479d84836ed Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 29 Nov 2022 15:34:12 -0800 Subject: [PATCH 2/5] [empty] run code checks From bb082fcf33f8acf723eabaf3093724de7d787ce6 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 29 Nov 2022 15:41:58 -0800 Subject: [PATCH 3/5] [empty] run code checks From 
1ee83a56744eb43d96f52ebc6cf71b4f1c770beb Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 29 Nov 2022 17:03:46 -0800 Subject: [PATCH 4/5] Add test for training RL on coinrun --- tests/scripts/test_scripts.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 226b6b3c2..89386b984 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -749,6 +749,29 @@ def test_train_rl_cnn_policy(tmpdir: str, rng): assert isinstance(run.result, dict) +def test_train_rl_coinrun(tmpdir: str, rng): + venv = util.make_vec_env( + "procgen:procgen-coinrun-v0", + n_envs=1, + parallel=False, + rng=rng, + ) + net = reward_nets.CnnRewardNet(venv.observation_space, venv.action_space) + tmppath = os.path.join(tmpdir, "reward.pt") + th.save(net, tmppath) + + log_dir_data = os.path.join(tmpdir, "train_rl") + run = train_rl.train_rl_ex.run( + named_configs=["train.cnn_policy"] + ALGO_FAST_CONFIGS["rl"] + ["coinrun"], + config_updates=dict( + common=dict(log_dir=log_dir_data), + reward_path=tmppath, + ), + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + PARALLEL_CONFIG_UPDATES = [ dict( sacred_ex_name="train_rl", From 245e8818a4260766432727d6a1b7a14517ffcc18 Mon Sep 17 00:00:00 2001 From: Daniel Filan Date: Tue, 29 Nov 2022 18:28:27 -0800 Subject: [PATCH 5/5] Start approximating the algorithm used in goal misgen paper --- src/imitation/scripts/common/rl.py | 8 ++++++++ src/imitation/scripts/config/train_rl.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/imitation/scripts/common/rl.py b/src/imitation/scripts/common/rl.py index 2bd3759a2..c0fd01ab8 100644 --- a/src/imitation/scripts/common/rl.py +++ b/src/imitation/scripts/common/rl.py @@ -79,6 +79,14 @@ def sac(): locals() # quieten flake8 +@rl_ingredient.named_config +def procgen_default(): + # copying the hyperparams used in "Goal Misgeneralization in Deep + # 
Reinforcement Learning" + rl_cls = sb3.PPO + rl_kwargs = dict(gamma=0.999, learning_rate=5e-4, batch_size=2048, ent_coef=0.01, n_steps=1024) + + def _maybe_add_relabel_buffer( rl_kwargs: Dict[str, Any], relabel_reward_fn: Optional[RewardFn] = None, diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index e0a2fc8b9..3797735f2 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -138,81 +138,97 @@ def seals_walker(): @train_rl_ex.named_config def coinrun(): common = dict(env_name="procgen:procgen-coinrun-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def maze(): common = dict(env_name="procgen:procgen-maze-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def bigfish(): common = dict(env_name="procgen:procgen-bigfish-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def bossfight(): common = dict(env_name="procgen:procgen-bossfight-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def caveflyer(): common = dict(env_name="procgen:procgen-caveflyer-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def chaser(): common = dict(env_name="procgen:procgen-chaser-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def climber(): common = dict(env_name="procgen:procgen-climber-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def dodgeball(): common = dict(env_name="procgen:procgen-dodgeball-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def fruitbot(): common = dict(env_name="procgen:procgen-fruitbot-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def heist(): common = dict(env_name="procgen:procgen-heist-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def jumper(): common = dict(env_name="procgen:procgen-jumper-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def leaper(): common = dict(env_name="procgen:procgen-leaper-v0") + total_timesteps = int(2e8) 
@train_rl_ex.named_config def miner(): common = dict(env_name="procgen:procgen-miner-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def ninja(): common = dict(env_name="procgen:procgen-ninja-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def plunder(): common = dict(env_name="procgen:procgen-plunder-v0") + total_timesteps = int(2e8) @train_rl_ex.named_config def starpilot(): common = dict(env_name="procgen:procgen-starpilot-v0") + total_timesteps = int(2e8) # Debug configs