Commit 72ab7aa

Running make_vec_envs with the pricing env
1 parent 2402f37 commit 72ab7aa

File tree

11 files changed: +324 additions, -251 deletions

ddopai/_modidx.py

Lines changed: 2 additions & 0 deletions
@@ -1684,6 +1684,8 @@
                          'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
     'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.reset_task': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.reset_task',
                          'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
+    'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.seed': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.seed',
+                         'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
     'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.step': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.step',
                          'ddopai/meta_learning/environments/pricing_env/pricing_env.py'),
     'ddopai.meta_learning.environments.pricing_env.pricing_env.PricingEnv.visualise_behaviour': ( '50_meta_learning/53_environments/01_pricing_env/pricing_env.html#pricingenv.visualise_behaviour',
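Note: _modidx.py is the symbol index that nbdev regenerates on export, so these two entries were not hand-edited; they simply register the new PricingEnv.seed method (introduced in the pricing_env.py hunk below) for documentation links.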

ddopai/meta_learning/environments/env_utils/vec_env/dummy_vec_env.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def __init__(self, env_fns):
         self.keys, shapes, dtypes = obs_space_info(obs_space)

         self.buf_obs = {k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys}
-        self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
+        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
         self.buf_infos = [{} for _ in range(self.num_envs)]
         self.actions = None
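Side note on the one-line fix above: np.bool was a deprecated alias for the built-in bool and was removed in NumPy 1.24, so the old line raises AttributeError on current NumPy. A minimal check of the replacement, assuming only NumPy itself:

import numpy as np

# dtype=bool produces NumPy's boolean dtype, the same one the removed
# np.bool alias used to resolve to.
buf_dones = np.zeros((4,), dtype=bool)
assert buf_dones.dtype == np.bool_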

ddopai/meta_learning/environments/parallel_envs.py

Lines changed: 35 additions & 35 deletions
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/50_meta_learning/53_environments/20_parralel_envs.ipynb.

 # %% auto 0
-__all__ = ['make_env', 'make_vec_envs', 'VecPyTorch']
+__all__ = ['make_env', 'VecPyTorch', 'make_vec_envs']

 # %% ../../../nbs/50_meta_learning/53_environments/20_parralel_envs.ipynb 1
 import gym
@@ -31,40 +31,6 @@ def _thunk():
     return _thunk

 # %% ../../../nbs/50_meta_learning/53_environments/20_parralel_envs.ipynb 3
-def make_vec_envs(env_name, seed, num_processes, gamma,
-                  device, episodes_per_task,
-                  normalise_rew, ret_rms,
-                  args, mode='train',
-                  rank_offset=0,
-                  **kwargs):
-    """
-    :param ret_rms: running return and std for rewards
-    """
-    envs = [make_env(env_id=env_name, seed=seed, rank=rank_offset + i,
-                     episodes_per_task=episodes_per_task,
-                     mode=mode, args=args, **kwargs)
-            for i in range(num_processes)]
-
-    if len(envs) > 1:
-        envs = SubprocVecEnv(envs)
-    else:
-        envs = DummyVecEnv(envs)
-
-    if len(envs.observation_space.shape) == 1:
-        if ret_rms is not None:
-            # copy this here to make sure the new envs don't change the return stats where this comes from
-            ret_rms = copy.copy(ret_rms)
-
-        envs = VecNormalize(envs,
-                            normalise_rew=normalise_rew, ret_rms=ret_rms,
-                            gamma=0.99 if gamma is None else gamma,
-                            cliprew=args.norm_rew_clip_param if 'norm_rew_clip_param' in vars(args) else 10.0)
-
-    envs = VecPyTorch(envs, device)
-
-    return envs
-
-# %% ../../../nbs/50_meta_learning/53_environments/20_parralel_envs.ipynb 4
 class VecPyTorch(VecEnvWrapper):
     def __init__(self, venv, device):
         """Return only every `skip`-th frame"""
@@ -128,3 +94,37 @@ def hooked(*args, **kwargs):
             return hooked
         else:
             return orig_attr
+
+# %% ../../../nbs/50_meta_learning/53_environments/20_parralel_envs.ipynb 4
+def make_vec_envs(env_name, seed, num_processes, gamma,
+                  device, episodes_per_task,
+                  normalise_rew, ret_rms,
+                  args, mode='train',
+                  rank_offset=0,
+                  **kwargs):
+    """
+    :param ret_rms: running return and std for rewards
+    """
+    envs = [make_env(env_id=env_name, seed=seed, rank=rank_offset + i,
+                     episodes_per_task=episodes_per_task,
+                     mode=mode, args=args, **kwargs)
+            for i in range(num_processes)]
+
+    if len(envs) > 1:
+        envs = SubprocVecEnv(envs)
+    else:
+        envs = DummyVecEnv(envs)
+
+    if len(envs.observation_space.shape) == 1:
+        if ret_rms is not None:
+            # copy this here to make sure the new envs don't change the return stats where this comes from
+            ret_rms = copy.copy(ret_rms)
+
+        envs = VecNormalize(envs,
+                            normalise_rew=normalise_rew, ret_rms=ret_rms,
+                            gamma=0.99 if gamma is None else gamma,
+                            cliprew=args.norm_rew_clip_param if 'norm_rew_clip_param' in vars(args) else 10.0)
+
+    envs = VecPyTorch(envs, device)
+
+    return envs
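The three hunks above only relocate make_vec_envs below the VecPyTorch wrapper it instantiates; since the module is autogenerated by nbdev, the move mirrors a cell reorder in 20_parralel_envs.ipynb and the function body is unchanged. A hypothetical call, where the env name and args values are placeholders rather than anything fixed by this commit:

import torch
from types import SimpleNamespace

# norm_rew_clip_param is the one attribute the function reads off args here;
# 'Pricing-v0' is an assumed registered env id, not taken from this commit.
args = SimpleNamespace(norm_rew_clip_param=10.0)
envs = make_vec_envs(env_name='Pricing-v0', seed=42, num_processes=4,
                     gamma=0.99, device=torch.device('cpu'),
                     episodes_per_task=2, normalise_rew=True, ret_rms=None,
                     args=args, mode='train')
obs = envs.reset()  # VecPyTorch hands observations back as torch tensors on `device`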

ddopai/meta_learning/environments/pricing_env/pricing_env.py

Lines changed: 4 additions & 2 deletions
@@ -113,7 +113,7 @@ def __init__(self,
             low=-self._BIG, high=self._BIG,
             shape=(self.task_dim,), dtype=np.float32
         )
-
+        self.seed()
         # -------- set latent task & episode ----------------------------------
         self.reset_task(task)
         self.reset()
@@ -192,7 +192,9 @@ def step(self, action):
         }
         return obs, reward, done, info

-
+    def seed(self, seed=None):
+        self.np_random, seed = gym.utils.seeding.np_random(seed)
+        return [seed]
     # ===================================================================== #
     #                          internal helpers                             #
     # ===================================================================== #
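The new seed method follows the classic Gym pattern: gym.utils.seeding.np_random(seed) returns an (rng, seed) pair, and the environment then draws all of its randomness from self.np_random. A small sketch of the reproducibility this buys, assuming a gym version whose seeding helper has that return signature:

from gym.utils import seeding

# Two RNGs seeded identically produce identical streams, which is what makes
# PricingEnv.seed(...) sufficient for reproducible episodes.
rng_a, _ = seeding.np_random(123)
rng_b, _ = seeding.np_random(123)
assert rng_a.uniform() == rng_b.uniform()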

ddopai/meta_learning/environments/wrappers.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def __init__(self,
         if not hasattr(self.env.unwrapped, 'num_states'):
             self.env.unwrapped.num_states = None
         if not hasattr(self.env.unwrapped, '_max_episode_steps'): # Meta-World ML10/ML45
-            self.env.unwrapped._max_episode_steps = env.max_path_length
+            self.env.unwrapped._max_episode_steps = env.horizon

         if episodes_per_task > 1:
             self.add_done_info = True

ddopai/meta_learning/utils/helpers.py

Lines changed: 35 additions & 83 deletions
@@ -18,93 +18,45 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-
+from ..environments.pricing_env.pricing_env import PricingEnv
+from ..environments.wrappers import PrevActRewWrapper
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 # %% ../../../nbs/50_meta_learning/50_utils/20_helpers.ipynb 2
-def make_env(args, mode='train', train_task_override=None, **kwargs):
-    env_id = args.env_name
-
-    # NEW ENV: METAWORLD
-    if env_id.startswith('metaworld'):
-
-        if args.mw_version == 1:
-            from environments.metaworld import metaworld
-        elif args.mw_version == 2:
-            from environments.metaworld_v2 import metaworld
-
-        env_type = 'metaworld'
-
-        # --- ML1 ---
-        # import the right meta-world-environment
-        if env_id == 'metaworld_ml1':
-            env_name = f'{args.ml1_type}-v{args.mw_version}'
-            mworld = metaworld.ML1(env_name) # Construct the benchmark, sampling tasks
-            # set up train/test env
-            if mode == 'train':
-                env = mworld.train_classes[env_name]()
-                if train_task_override is not None:
-                    env.reset_task = lambda: env.set_task(random.choice(train_task_override))
-                else:
-                    env.reset_task = lambda: env.set_task(random.choice(mworld.train_tasks))
-            elif mode == 'test':
-                env = mworld.test_classes[env_name]()
-                env.reset_task = lambda: env.set_task(random.choice(mworld.test_tasks))
-
-        # --- ML10 ---
-        elif env_id == 'metaworld_ml10':
-            ml10 = metaworld.ML10()
-            # if mode == 'train':
-            #     n_envs = 10
-            # elif mode == 'test':
-            #     n_envs = 5
-            # else:
-            #     raise ValueError
-
-            # n_tasks = n_envs * 1 # Leo: This ensures 1 env of each is sampled.
-            from environments.garage.experiment.task_sampler import MetaWorldTaskSampler # Can't do this at top since it breaks MuJoCo131 needed for Walker
-            task_sampler = MetaWorldTaskSampler(ml10,
-                                                mode, # train or test
-                                                wrapper=None,
-                                                # lambda env, _: normalize(env), # TODO: not sure if we should use this
-                                                add_env_onehot=False)
-            # envs = [env_up() for env_up in task_sampler.sample(n_tasks)]
-            from environments.mw_wrapper import MetaWorldMultiEnvWrapper # Can't do this at top since it breaks MuJoCo131 needed for Walker
-            env = MetaWorldMultiEnvWrapper(task_sampler,
-                                           n_tasks_train=10,
-                                           n_tasks_test=5, # needed to make one-hot ids
-                                           mode='vanilla',
-                                           train=(mode=='train'))
-        else:
-            raise ValueError
-        env._max_episode_steps = env.max_path_length
-    elif env_id.startswith('T-') or env_id.startswith('MC-'):
-        env = gym.make(env_id, **kwargs)
-        env_type = "Maze"
-    # OTHERWISE WE ASSUME ITS A GYM ENV
-    else:
-        env_type = 'gym'
-        if args is not None and args.env_name == 'RoomNavi-v0':
-            env = gym.make(env_id,
-                           num_cells=args.num_cells,
-                           corridor_len=args.corridor_len,
-                           num_steps=args.horizon,
-                           **kwargs)
-        if args is not None and args.env_name == 'TreasureHunt-v0':
-            env = gym.make(env_id,
-                           max_episode_steps=args.max_episode_steps,
-                           mountain_height=args.mountain_height,
-                           treasure_reward=args.treasure_reward,
-                           timestep_penalty=args.timestep_penalty,
-                           **kwargs)
-        elif args is not None and args.env_name == 'AntGoalSparse-v0':
-            env = gym.make(env_id,
-                           level=args.level,
-                           **kwargs)
-        else:
-            env = gym.make(env_id, **kwargs)
+# --------------------------------------------------------------------
+def make_env(args, mode='train', **kwargs):
+    """
+    Create **one** PricingEnv with the requested wrappers.
+
+    Parameters
+    ----------
+    args : argparse.Namespace – needs at least
+        * env_name           (should be 'Pricing-v0' or similar)
+        * pricing_kwargs     (dict forwarded to PricingEnv)
+        * ar_in_state        (bool, adds PrevActRewWrapper)
+        * max_episode_length (int, TimeLimit wrapper)
+    mode : 'train' | 'test' – kept for API compatibility; ignored here.
+    """
+    assert args.env_name.lower().startswith('pricing'), \
+        "This trimmed helper only supports PricingEnv."
+
+    # base env --------------------------------------------------------
+    env = PricingEnv(**args.pricing_kwargs)
+
+    # RL^2 needs (s_{t-1}, a_{t-1}, r_{t-1}) in the observation
+    if args.ar_in_state:
+        env = PrevActRewWrapper(env)
+
+
+    # optional obs normalisation (reuse Hyper’s wrapper if desired)
+    if getattr(args, "norm_obs", False):
+        from ddopai.meta_learning.environments.wrappers import NormaliseObservations
+        env = NormaliseObservations(env, clip=10.0, eps=1e-8,
+                                    training=(mode == 'train'))
+
+
+    return env, "pricing"

-    return env, env_type


 def reset_env(env, args, indices=None, state=None):
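A hypothetical invocation of the trimmed helper above; the attribute names come from its docstring, but the values are placeholders, and PricingEnv may require more constructor arguments than the empty dict shown:

from types import SimpleNamespace

args = SimpleNamespace(env_name='Pricing-v0',
                       pricing_kwargs={},  # forwarded verbatim to PricingEnv
                       ar_in_state=True)   # adds PrevActRewWrapper for RL^2
env, env_type = make_env(args, mode='train')
assert env_type == "pricing"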
