|
18 | 18 | import torch |
19 | 19 | import torch.nn as nn |
20 | 20 | from torch.nn import functional as F |
21 | | - |
| 21 | +from ..environments.pricing_env.pricing_env import PricingEnv |
| 22 | +from ..environments.wrappers import PrevActRewWrapper |
22 | 23 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
23 | 24 |
|
24 | 25 | # %% ../../../nbs/50_meta_learning/50_utils/20_helpers.ipynb 2 |
25 | | -def make_env(args, mode='train', train_task_override=None, **kwargs): |
26 | | - env_id = args.env_name |
27 | | - |
28 | | - # NEW ENV: METAWORLD |
29 | | - if env_id.startswith('metaworld'): |
30 | | - |
31 | | - if args.mw_version == 1: |
32 | | - from environments.metaworld import metaworld |
33 | | - elif args.mw_version == 2: |
34 | | - from environments.metaworld_v2 import metaworld |
35 | | - |
36 | | - env_type = 'metaworld' |
37 | | - |
38 | | - # --- ML1 --- |
39 | | - # import the right meta-world-environment |
40 | | - if env_id == 'metaworld_ml1': |
41 | | - env_name = f'{args.ml1_type}-v{args.mw_version}' |
42 | | - mworld = metaworld.ML1(env_name) # Construct the benchmark, sampling tasks |
43 | | - # set up train/test env |
44 | | - if mode == 'train': |
45 | | - env = mworld.train_classes[env_name]() |
46 | | - if train_task_override is not None: |
47 | | - env.reset_task = lambda: env.set_task(random.choice(train_task_override)) |
48 | | - else: |
49 | | - env.reset_task = lambda: env.set_task(random.choice(mworld.train_tasks)) |
50 | | - elif mode == 'test': |
51 | | - env = mworld.test_classes[env_name]() |
52 | | - env.reset_task = lambda: env.set_task(random.choice(mworld.test_tasks)) |
53 | | - |
54 | | - # --- ML10 --- |
55 | | - elif env_id == 'metaworld_ml10': |
56 | | - ml10 = metaworld.ML10() |
57 | | - # if mode == 'train': |
58 | | - # n_envs = 10 |
59 | | - # elif mode == 'test': |
60 | | - # n_envs = 5 |
61 | | - # else: |
62 | | - # raise ValueError |
63 | | - |
64 | | - # n_tasks = n_envs * 1 # Leo: This ensures 1 env of each is sampled. |
65 | | - from environments.garage.experiment.task_sampler import MetaWorldTaskSampler # Can't do this at top since it breaks MuJoCo131 needed for Walker |
66 | | - task_sampler = MetaWorldTaskSampler(ml10, |
67 | | - mode, # train or test |
68 | | - wrapper=None, |
69 | | - # lambda env, _: normalize(env), # TODO: not sure if we should use this |
70 | | - add_env_onehot=False) |
71 | | - # envs = [env_up() for env_up in task_sampler.sample(n_tasks)] |
72 | | - from environments.mw_wrapper import MetaWorldMultiEnvWrapper # Can't do this at top since it breaks MuJoCo131 needed for Walker |
73 | | - env = MetaWorldMultiEnvWrapper(task_sampler, |
74 | | - n_tasks_train=10, |
75 | | - n_tasks_test=5, # needed to make one-hot ids |
76 | | - mode='vanilla', |
77 | | - train=(mode=='train')) |
78 | | - else: |
79 | | - raise ValueError |
80 | | - env._max_episode_steps = env.max_path_length |
81 | | - elif env_id.startswith('T-') or env_id.startswith('MC-'): |
82 | | - env = gym.make(env_id, **kwargs) |
83 | | - env_type = "Maze" |
84 | | - # OTHERWISE WE ASSUME ITS A GYM ENV |
85 | | - else: |
86 | | - env_type = 'gym' |
87 | | - if args is not None and args.env_name == 'RoomNavi-v0': |
88 | | - env = gym.make(env_id, |
89 | | - num_cells=args.num_cells, |
90 | | - corridor_len=args.corridor_len, |
91 | | - num_steps=args.horizon, |
92 | | - **kwargs) |
93 | | - if args is not None and args.env_name == 'TreasureHunt-v0': |
94 | | - env = gym.make(env_id, |
95 | | - max_episode_steps=args.max_episode_steps, |
96 | | - mountain_height=args.mountain_height, |
97 | | - treasure_reward=args.treasure_reward, |
98 | | - timestep_penalty=args.timestep_penalty, |
99 | | - **kwargs) |
100 | | - elif args is not None and args.env_name == 'AntGoalSparse-v0': |
101 | | - env = gym.make(env_id, |
102 | | - level=args.level, |
103 | | - **kwargs) |
104 | | - else: |
105 | | - env = gym.make(env_id, **kwargs) |
# --------------------------------------------------------------------
def make_env(args, mode='train', **kwargs):
    """
    Create **one** PricingEnv with the requested wrappers.

    Parameters
    ----------
    args : argparse.Namespace – needs at least
        * env_name (should be 'Pricing-v0' or similar)
        * pricing_kwargs (dict forwarded to PricingEnv)
        * ar_in_state (bool, adds PrevActRewWrapper)
        * norm_obs (bool, optional, adds NormaliseObservations)
    mode : 'train' | 'test' – kept for API compatibility; only consulted
        to put the optional observation-normalisation wrapper into
        training mode.
    **kwargs : extra keyword arguments forwarded to PricingEnv; they
        override duplicate keys in ``args.pricing_kwargs``.

    Returns
    -------
    (env, env_type) : the wrapped environment plus the constant string
        "pricing", preserving the old ``(env, env_type)`` return contract.

    Raises
    ------
    ValueError
        If ``args.env_name`` is not a Pricing environment id.
    """
    # NOTE: `assert` is stripped under `python -O`, so validate explicitly
    # to keep the guard alive in optimized runs.
    if not args.env_name.lower().startswith('pricing'):
        raise ValueError("This trimmed helper only supports PricingEnv.")

    # base env -------------------------------------------------------
    # Caller-supplied kwargs were previously forwarded to env construction;
    # merge them over the defaults instead of silently dropping them.
    env = PricingEnv(**{**args.pricing_kwargs, **kwargs})

    # RL^2 needs (s_{t-1}, a_{t-1}, r_{t-1}) in the observation
    if args.ar_in_state:
        env = PrevActRewWrapper(env)

    # optional obs normalisation (reuse Hyper's wrapper if desired);
    # running statistics only update while training.
    if getattr(args, "norm_obs", False):
        from ddopai.meta_learning.environments.wrappers import NormaliseObservations
        env = NormaliseObservations(env, clip=10.0, eps=1e-8,
                                    training=(mode == 'train'))

    return env, "pricing"
106 | 59 |
|
107 | | - return env, env_type |
108 | 60 |
|
109 | 61 |
|
110 | 62 | def reset_env(env, args, indices=None, state=None): |
|
0 commit comments