-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathcreate_env.py
More file actions
115 lines (101 loc) · 4.47 KB
/
create_env.py
File metadata and controls
115 lines (101 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import enum
import time
from copy import deepcopy
import numpy as np
from gymnasium import Wrapper
from loguru import logger
from pogema import AnimationConfig, AnimationMonitor, pogema_v0
from pogema.generator import generate_from_possible_targets, generate_new_target
from pogema.wrappers.metrics import AgentsDensityWrapper, RuntimeMetricWrapper
from pogema_toolbox.create_env import MultiMapWrapper
class ProvideFutureTargetsWrapper(Wrapper):
    """Expose every agent's upcoming goal chain in the observations returned by ``reset``.

    For grid configs with ``on_target == "restart"``, ``reset`` replays the
    per-agent random generators on a deep copy (so the wrapped environment's
    own generators are not advanced). This predicts the chain of targets each
    agent will be given, and the chain is stored under
    ``obs["global_lifelong_targets_xy"]``.

    ``grid``, ``grid_config`` and ``random_generators`` are read through the
    Wrapper from the wrapped environment.
    """

    def _sample_following_goal(self, rng, previous_goal):
        """Return the goal that ``rng`` draws after ``previous_goal``.

        Without ``possible_targets_xy`` the goal comes from
        ``generate_new_target`` over the grid's connected components. Otherwise
        it is drawn from the configured candidates and shifted by ``obs_radius``.
        """
        config = self.grid_config
        if config.possible_targets_xy is not None:
            picked = generate_from_possible_targets(
                rng,
                config.possible_targets_xy,
                previous_goal,
            )
            # NOTE(review): the shift assumes the candidate list is in un-padded
            # coordinates while the grid works in padded ones — confirm upstream.
            return (
                picked[0] + config.obs_radius,
                picked[1] + config.obs_radius,
            )
        return generate_new_target(
            rng,
            self.grid.point_to_component,
            self.grid.component_to_points,
            previous_goal,
        )

    def _get_lifelong_global_targets_xy(self):
        """Predict the sequence of goals each agent will be assigned.

        Returns:
            A list with one entry per agent. Each entry is a list of goals that
            starts with the agent's current target. It is extended until the
            summed Manhattan distance between consecutive goals is at least
            ``max_episode_steps + 100``.
        """
        starting_goals = self.grid.get_targets_xy()
        # Work on copies so predicting targets does not consume the env's RNG state.
        rngs = deepcopy(self.random_generators)
        per_agent = []
        for agent_idx in range(self.grid_config.num_agents):
            rng = rngs[agent_idx]
            chain = [starting_goals[agent_idx]]
            travelled = 0
            while travelled < self.grid_config.max_episode_steps + 100:
                last = chain[-1]
                following = self._sample_following_goal(rng, last)
                travelled += abs(last[0] - following[0]) + abs(last[1] - following[1])
                chain.append(following)
            per_agent.append(chain)
        return per_agent

    def reset(self, **kwargs):
        """Reset with the config's seed and annotate the returned observations.

        ``kwargs`` are accepted for API compatibility but are not forwarded; the
        wrapped env is always reset with ``grid_config.seed``.

        Returns:
            ``(observations, infos)`` from the wrapped env. ``observations[0]``
            gains ``"after_reset"`` and ``"max_episode_steps"``. For
            ``on_target == "restart"``, every observation also gains
            ``"global_lifelong_targets_xy"``.
        """
        observations, infos = self.env.reset(seed=self.env.grid_config.seed)
        head = observations[0]
        head["after_reset"] = True
        head["max_episode_steps"] = self.env.grid_config.max_episode_steps
        if self.env.grid_config.on_target != "restart":
            return observations, infos
        future_targets = self._get_lifelong_global_targets_xy()
        for idx, obs in enumerate(observations):
            obs["global_lifelong_targets_xy"] = future_targets[idx]
        return observations, infos
class LogActions(Wrapper):
    """Record per-agent actions and start positions and report them at episode end.

    When every agent is terminated, or every agent is truncated, the recorded
    data is written into ``infos[0]["metrics"]``:

    - ``"made_actions"``: one list of actions per agent.
    - ``"init_positions"``: each agent's ``obs["global_xy"]`` at reset.
    - ``"global_lifelong_targets_xy"``: only for ``on_target == "restart"``.
      These are the targets read from ``obs["global_lifelong_targets_xy"]`` at
      reset (e.g. as provided by ``ProvideFutureTargetsWrapper``), converted
      to lists of ``int`` pairs.
    """

    def __init__(self, env):
        super().__init__(env)
        self.made_actions = None
        self.init_positions = None
        # Populated by reset() only for lifelong ("restart") environments;
        # defined here so the attribute always exists on the instance.
        self.global_lifelong_targets_xy = None

    def step(self, actions):
        """Step the wrapped env, appending each agent's action to its log.

        ``reset`` must have been called first (it creates the per-agent logs).
        """
        observations, rewards, terminated, truncated, infos = self.env.step(actions)
        for i, action in enumerate(actions):
            self.made_actions[i].append(action)
        if all(terminated) or all(truncated):
            # setdefault: do not fail with KeyError if no upstream metric
            # wrapper has created the "metrics" dict for this step.
            metrics = infos[0].setdefault("metrics", {})
            metrics["made_actions"] = self.made_actions
            metrics["init_positions"] = self.init_positions
            if self.env.grid_config.on_target == "restart":
                metrics["global_lifelong_targets_xy"] = self.global_lifelong_targets_xy
        return observations, rewards, terminated, truncated, infos

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding ``kwargs``) and start fresh logs."""
        observations, info = self.env.reset(**kwargs)
        self.made_actions = [[] for _ in observations]
        self.init_positions = [obs["global_xy"] for obs in observations]
        if self.env.grid_config.on_target == "restart":
            # Plain ints so the logged targets do not carry numpy scalar types.
            self.global_lifelong_targets_xy = [
                [[int(x), int(y)] for x, y in obs["global_lifelong_targets_xy"]]
                for obs in observations
            ]
        return observations, info
def create_eval_env(config):
    """Build the evaluation environment described by ``config``.

    Wrapping order, innermost first: ``pogema_v0``, ``AgentsDensityWrapper``,
    ``MultiMapWrapper``, then ``AnimationMonitor`` (only when
    ``config.with_animation`` is truthy), and finally ``RuntimeMetricWrapper``.
    """
    env = MultiMapWrapper(AgentsDensityWrapper(pogema_v0(grid_config=config)))
    if config.with_animation:
        logger.debug("Wrapping environment with AnimationMonitor")
        env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None))
    return RuntimeMetricWrapper(env)
def create_logging_env(config):
    """Build an environment that records actions and future targets for ``config``.

    Wrapping order, innermost first: ``pogema_v0``, ``AgentsDensityWrapper``,
    ``ProvideFutureTargetsWrapper``, ``MultiMapWrapper``, ``LogActions``, then
    ``AnimationMonitor`` (only when ``config.with_animation`` is truthy), and
    finally ``RuntimeMetricWrapper`` for runtime metrics.
    """
    env = pogema_v0(grid_config=config)
    for wrapper_cls in (
        AgentsDensityWrapper,
        ProvideFutureTargetsWrapper,
        MultiMapWrapper,
        LogActions,
    ):
        env = wrapper_cls(env)
    if config.with_animation:
        logger.debug("Wrapping environment with AnimationMonitor")
        env = AnimationMonitor(env, AnimationConfig(save_every_idx_episode=None))
    return RuntimeMetricWrapper(env)