-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlunar-lander-eval.py
More file actions
65 lines (52 loc) · 2.17 KB
/
lunar-lander-eval.py
File metadata and controls
65 lines (52 loc) · 2.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gymnasium as gym
from sb3_contrib import RecurrentPPO
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import SubprocVecEnv, VecFrameStack, VecNormalize
from argparse import ArgumentParser
from wrappers.exclude import ExcludeObservationsWrapper
# ---------------------------------------------------------------------------
# Command-line configuration for the evaluation run.
#
# Exactly one of three model variants must be selected:
#   --frame-stack N : frame-stacked feed-forward policy (N > 0)
#   --lstm          : recurrent (LSTM) policy on the reduced observation
#   --full          : feed-forward policy on the full observation vector
# NAME encodes the chosen variant and is used to locate the saved model and
# VecNormalize statistics on disk.
# ---------------------------------------------------------------------------
parser = ArgumentParser()
parser.add_argument("--try", type=int, default=1)  # run/experiment index
parser.add_argument("--frame-stack", type=int, default=0, required=False)
parser.add_argument("--full", action="store_true", required=False, default=False)
parser.add_argument("--lstm", action="store_true", required=False, default=False)
# vars() is needed because "try" is a Python keyword, so `args.try` would be
# a syntax error; dict access via args["try"] works fine.
args = vars(parser.parse_args())

ENV_SLUG = "lunar-lander"
FRAME_STACK = args["frame_stack"]
TRY = args["try"]
FULL = args["full"]
LSTM = args["lstm"]

if not FULL and not LSTM and FRAME_STACK == 0:
    # BUGFIX: message previously said "--frames-stack"; the real flag
    # (defined above) is "--frame-stack".
    raise ValueError("You must specify --full, --lstm or --frame-stack")
if FRAME_STACK != 0:
    if LSTM or FULL:
        raise ValueError("Frame stack cannot be used with --lstm or --full")
    NAME = f"FS-{FRAME_STACK}"
elif LSTM:
    if FULL:
        raise ValueError("--lstm and --full cannot be used together")
    NAME = "LSTM"
else:
    NAME = "FULL"
def make_env():
    """Create one continuous LunarLander-v3 environment for evaluation.

    Unless --full was requested, observation indices 2, 3 and 5 are hidden
    via ExcludeObservationsWrapper, matching the training-time observation
    space of the reduced models.
    """
    base = gym.make("LunarLander-v3", continuous=True, disable_env_checker=True)
    if FULL:
        return base
    return ExcludeObservationsWrapper(base, [2, 3, 5])
def wrap(env):
    """Apply optional frame stacking to a vectorized environment.

    When --frame-stack is 0 (LSTM/full modes) the env passes through
    unchanged; otherwise the last FRAME_STACK observations are stacked.
    """
    if FRAME_STACK == 0:
        return env
    return VecFrameStack(env, n_stack=FRAME_STACK)
if __name__ == "__main__":
eval_env = make_vec_env(make_env, n_envs=10, vec_env_cls=SubprocVecEnv)
eval_env = wrap(eval_env)
eval_env = VecNormalize.load(f"vec_normalize/{ENV_SLUG}/{NAME}_{TRY}", venv=eval_env)
if LSTM:
model = RecurrentPPO.load(f"best/{ENV_SLUG}/{NAME}_{TRY}/best_model.zip")
else:
model = PPO.load(f"best/{ENV_SLUG}/{NAME}_{TRY}/best_model.zip", device="cpu")
print("Evaluating...")
mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=100, deterministic=False)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")