-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathenv_utils.py
More file actions
162 lines (125 loc) · 4.52 KB
/
env_utils.py
File metadata and controls
162 lines (125 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from tqdm import tqdm
import numpy as np
from composuite.env.gym_wrapper import GymWrapper
from typing import Sequence
from abc import ABC, abstractmethod
from copy import deepcopy
import numpy as np
# fmt: off
# Import-time side effect: logging.basicConfig is called with level INFO as
# soon as this module is imported, then a module-level logger named after
# this module is created.
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# fmt: on
class CompoSuiteGymnasiumWrapper(GymWrapper):
    """GymWrapper variant whose ``reset``/``step`` return longer tuples.

    ``reset`` returns ``(observation, {})`` and ``step`` returns
    ``(observation, reward, done, truncated, info)``. The ``truncated`` flag
    is derived from a local step counter compared against ``self.horizon``.

    NOTE(review): ``horizon`` is not defined in this class; presumably it is
    provided by ``GymWrapper`` or the wrapped environment -- confirm.
    """

    def __init__(self, env):
        """Wrap ``env`` and start the step counter at zero."""
        super().__init__(env)
        self.step_counter = 0

    def reset(self):
        """Reset via the parent class and return ``(observation, {})``.

        The step counter is cleared after the parent reset has returned.
        """
        observation = super().reset()
        self.step_counter = 0
        empty_info = {}
        return observation, empty_info

    def step(self, action):
        """Step via the parent class and append a ``truncated`` flag.

        Args:
            action: Forwarded unchanged to the parent ``step``.

        Returns:
            ``(observation, reward, done, truncated, info)`` where the first
            three values and ``info`` come from the parent ``step``.
        """
        observation, reward, done, info = super().step(action)

        # The counter holds the number of steps taken before this one, so it
        # equals ``horizon - 1`` on the horizon-th step since the last wrap.
        truncated = bool(self.step_counter == self.horizon - 1)

        # Wrap back to zero on truncation; otherwise count this step.
        self.step_counter = 0 if truncated else self.step_counter + 1

        return observation, reward, done, truncated, info
class VecEnv(ABC):
    """Abstract base class for a group of environments handled as one unit.

    Subclasses must implement ``reset``, ``step_async``, ``step_wait``,
    ``close``, ``get_attr``, ``set_attr`` and ``env_method``. ``step`` is
    provided here and simply chains ``step_async`` and ``step_wait``.

    Attributes set by ``__init__``: ``num_envs``, ``observation_space`` and
    ``action_space`` (stored exactly as passed in).
    """

    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(self, num_envs, observation_space, action_space):
        """Store the environment count and the two space descriptions."""
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self, seed=None):
        """Reset the environments. Must be implemented by subclasses."""

    @abstractmethod
    def step_async(self, actions):
        """Receive ``actions`` for a step; called by ``step`` before ``step_wait``."""

    @abstractmethod
    def step_wait(self):
        """Produce the result of a step; its return value is what ``step`` returns."""

    @abstractmethod
    def close(self):
        """Release the environments. Must be implemented by subclasses."""

    @abstractmethod
    def get_attr(self, attr_name, indices=None):
        """Read ``attr_name`` from the environments selected by ``indices``."""

    @abstractmethod
    def set_attr(self, attr_name, value, indices=None):
        """Assign ``value`` to ``attr_name`` on the environments selected by ``indices``."""

    @abstractmethod
    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Call ``method_name`` on the environments selected by ``indices``."""

    def step(self, actions):
        """Run one step: ``step_async(actions)`` followed by ``step_wait()``."""
        self.step_async(actions)
        outcome = self.step_wait()
        return outcome

    def get_images(self) -> Sequence[np.ndarray]:
        """Not implemented in the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def render(self, mode: str = "human"):
        """Not implemented in the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def getattr_depth_check(self, name, already_found):
        """Report this object's qualified class name if it also defines ``name``.

        Returns:
            ``"<module>.<class name>"`` of ``type(self)`` when ``self`` has an
            attribute called ``name`` and ``already_found`` is truthy;
            otherwise ``None``.
        """
        defined_here = hasattr(self, name)
        if not (defined_here and already_found):
            return None
        cls = type(self)
        return f"{cls.__module__}.{cls.__name__}"

    def _get_indices(self, indices):
        """Normalise ``indices`` into an iterable of environment indices.

        ``None`` selects every environment (``range(self.num_envs)``), a single
        ``int`` becomes a one-element list, and anything else is returned
        unchanged.
        """
        if indices is None:
            return range(self.num_envs)
        if isinstance(indices, int):
            return [indices]
        return indices
class DummyVecEnv(VecEnv):
    """VecEnv that keeps its environments in a list and steps them in turn.

    Every operation is a plain Python loop over ``self.envs`` in the calling
    code path. Each environment is expected to return a 5-tuple from ``step``
    and an indexable result from ``reset`` whose first item is the
    observation.
    """

    def __init__(self, env_fns):
        """Build one environment per factory in ``env_fns``.

        Args:
            env_fns: Iterable of zero-argument callables, each returning an
                environment. Construction is wrapped in a ``tqdm`` progress
                bar. The first environment supplies the observation and
                action spaces, so ``env_fns`` must not be empty.
        """
        self.envs = [make_env() for make_env in tqdm(env_fns)]
        first_env = self.envs[0]
        # The base initializer stores num_envs / observation_space /
        # action_space on self.
        super().__init__(
            len(self.envs),
            first_env.observation_space,
            first_env.action_space,
        )

    def step_async(self, actions):
        """Remember ``actions`` so that ``step_wait`` can apply them."""
        self.actions = actions

    def step_wait(self):
        """Step every environment with its stored action and stack the results.

        Returns:
            ``(obs, rews, dones, infos)``, each an ``np.array`` built from the
            per-environment values.

        NOTE(review): the fourth element of each environment's step result
        (the position the wrapper in this file uses for ``truncated``) is
        converted but not returned -- confirm callers do not need it.
        """
        transitions = [
            env.step(action) for env, action in zip(self.envs, self.actions)
        ]
        # Transpose the list of 5-tuples into five columns, one array each.
        obs, rews, dones, _unused, infos = (
            np.array(column) for column in zip(*transitions)
        )
        return obs, rews, dones, infos

    def seed(self, seed=None):
        """Seeding is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def reset(self, seed=None):
        """Reset every environment and return the stacked first reset items.

        NOTE(review): ``seed`` is accepted but not used here -- confirm that
        is intended.
        """
        first_items = []
        for env in self.envs:
            first_items.append(env.reset()[0])
        return np.array(first_items)

    def close(self):
        """Call ``close()`` on every environment, in list order."""
        for env in self.envs:
            env.close()

    def get_images(self) -> Sequence[np.ndarray]:
        """Return ``env.render(mode="rgb_array")`` for every environment."""
        frames = []
        for env in self.envs:
            frames.append(env.render(mode="rgb_array"))
        return frames

    def render(self, mode: str = "human"):
        """Render the only environment, or defer to the base class.

        With exactly one environment its ``render()`` result is returned;
        otherwise the base-class ``render`` is called.

        NOTE(review): ``mode`` is not forwarded in either branch -- confirm.
        """
        if self.num_envs != 1:
            return super().render()
        return self.envs[0].render()

    def get_attr(self, attr_name, indices=None):
        """Return a list with ``attr_name`` read from each selected environment."""
        return [
            getattr(env, attr_name) for env in self._get_target_envs(indices)
        ]

    def set_attr(self, attr_name, value, indices=None):
        """Assign ``value`` to ``attr_name`` on each selected environment."""
        for env in self._get_target_envs(indices):
            setattr(env, attr_name, value)

    def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
        """Call ``method_name`` on each selected environment and collect the results."""
        results = []
        for env in self._get_target_envs(indices):
            method = getattr(env, method_name)
            results.append(method(*method_args, **method_kwargs))
        return results

    def _get_target_envs(self, indices):
        """Map ``indices`` (normalised by ``_get_indices``) to environment objects."""
        selected = self._get_indices(indices)
        return [self.envs[position] for position in selected]