
Commit dda50ee

Added flatland datasets.
1 parent b89f793 commit dda50ee

2 files changed: +59 additions, -16 deletions

og_marl/vault_utils/download_vault.py

Lines changed: 8 additions & 0 deletions
@@ -78,6 +78,14 @@
             "url": "https://huggingface.co/datasets/InstaDeepAI/og-marl/resolve/main/core/gymnasium_mamujoco/6halfcheetah.zip"
         },
     },
+    "flatland": {
+        "20_trains": {
+            "url": "https://huggingface.co/datasets/InstaDeepAI/og-marl/resolve/main/core/flatland/20_trains.zip"
+        },
+        "30_trains": {
+            "url": "https://huggingface.co/datasets/InstaDeepAI/og-marl/resolve/main/core/flatland/30_trains.zip"
+        },
+    },
 },
 "cfcql": {
     "smac_v1": {

og_marl/wrapped_environments/flatland_wrapper.py

Lines changed: 51 additions & 16 deletions
@@ -26,35 +26,63 @@
 from og_marl.wrapped_environments.base import BaseEnvironment, Observations, ResetReturn, StepReturn
 
 FLATLAND_MAP_CONFIGS = {
-    "3trains": {
+    "3_trains": {
         "num_trains": 3,
         "num_cities": 2,
         "width": 25,
         "height": 25,
         "max_episode_len": 80,
     },
-    "5trains": {
+    "5_trains": {
         "num_trains": 5,
         "num_cities": 2,
         "width": 25,
         "height": 25,
         "max_episode_len": 100,
     },
+    "20_trains": {
+        "num_trains": 20,
+        "num_cities": 3,
+        "width": 30,
+        "height": 30,
+        "max_episode_len": 100,
+    },
+    "30_trains": {
+        "num_trains": 30,
+        "num_cities": 3,
+        "width": 35,
+        "height": 30,
+        "max_episode_len": 100,
+    },
+    "40_trains": {
+        "num_trains": 40,
+        "num_cities": 4,
+        "width": 35,
+        "height": 35,
+        "max_episode_len": 100,
+    },
+    "50_trains": {
+        "num_trains": 50,
+        "num_cities": 4,
+        "width": 35,
+        "height": 35,
+        "max_episode_len": 100,
+    },
 }
 
 
 class Flatland(BaseEnvironment):
     def __init__(self, map_name: str = "5_trains"):
         map_config = FLATLAND_MAP_CONFIGS[map_name]
 
-        self._num_actions = 5
+        self.num_actions = 5
         self.num_agents = map_config["num_trains"]
         self._num_cities = map_config["num_cities"]
         self._map_width = map_config["width"]
         self._map_height = map_config["height"]
         self._tree_depth = 2
 
-        self.possible_agents = [f"{i}" for i in range(self.num_agents)]
+        self.agents = [f"{i}" for i in range(self.num_agents)]
 
         self.rail_generator = sparse_rail_generator(max_num_cities=self._num_cities)
 
@@ -75,19 +103,22 @@ def __init__(self, map_name: str = "5_trains"):
 
         self._obs_dim = 11 * sum(4**i for i in range(self._tree_depth + 1)) + 7
 
-        self.action_spaces = {agent: Discrete(self._num_actions) for agent in self.possible_agents}
+        self.action_spaces = {agent: Discrete(self.num_actions) for agent in self.agents}
         self.observation_spaces = {
-            agent: Box(-np.inf, np.inf, (self._obs_dim,)) for agent in self.possible_agents
+            agent: Box(-np.inf, np.inf, (self._obs_dim,)) for agent in self.agents
         }
 
         self.info_spec = {
             "state": np.zeros((11 * self.num_agents,), "float32"),
             "legals": {
-                agent: np.zeros((self._num_actions,), "int64") for agent in self.possible_agents
+                agent: np.zeros((self.num_actions,), "int64") for agent in self.agents
             },
         }
 
-        self.max_episode_length = map_config["max_episode_len"]
+
+    def render(self) -> Any:
+        """Return frame for rendering"""
+        return self._environment.render()
 
     def reset(self) -> ResetReturn:
         self._done = False
@@ -116,7 +147,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
         # Rewards
         rewards = {
             agent: np.array(all_rewards[int(agent)], dtype="float32")
-            for agent in self.possible_agents
+            for agent in self.agents
         }
 
         # Legal actions
@@ -130,21 +161,25 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
 
         info = {"state": state, "legals": legal_actions}
 
-        terminals = {agent: np.array(self._done) for agent in self.possible_agents}
-        truncations = {agent: np.array(False) for agent in self.possible_agents}
+        if self._done:
+            num_arrived = sum(self._environment.agents[int(agent)].state == 6 for agent in self.agents)
+            info["arrived"] = num_arrived
+
+        terminals = {agent: np.array(self._done) for agent in self.agents}
+        truncations = {agent: np.array(False) for agent in self.agents}
 
         return next_observations, rewards, terminals, truncations, info
 
     def _get_legal_actions(self) -> Dict[str, np.ndarray]:
         legal_actions = {}
-        for agent in self.possible_agents:
+        for agent in self.agents:
             agent_id = int(agent)
             flatland_agent = self._environment.agents[agent_id]
 
             if not self._environment.action_required(
-                flatland_agent.state, flatland_agent.speed_counter.is_cell_entry
+                flatland_agent
             ):
-                legals = np.zeros(self._num_actions, "float32")
+                legals = np.zeros(self.num_actions, "float32")
                 legals[0] = 1 # can only do nothng
             else:
                 legals = np.ones(5, "float32")
@@ -155,7 +190,7 @@ def _get_legal_actions(self) -> Dict[str, np.ndarray]:
 
     def _make_state_representation(self) -> np.ndarray:
         state = []
-        for i, _ in enumerate(self.possible_agents):
+        for i, _ in enumerate(self.agents):
             agent = self._environment.agents[i]
             state.append(np.array(agent.target, "float32"))
 
@@ -179,7 +214,7 @@ def _convert_observations(
         info: Dict[str, Dict[int, np.ndarray]],
     ) -> Observations:
         new_observations = {}
-        for i, agent in enumerate(self.possible_agents):
+        for i, agent in enumerate(self.agents):
             agent_id = i
             norm_observation = normalize_observation(
                 observations[agent_id],
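
As a quick sanity check on the renamed attributes (agents, num_actions) and the new map configurations, here is a minimal random-rollout sketch (not part of the commit). It assumes the flatland dependencies are installed, that reset() returns (observations, info) with the "legals" mask in info, and that the arrived-train count appears under info["arrived"] once the episode ends, as added above.

import numpy as np

from og_marl.wrapped_environments.flatland_wrapper import Flatland

# One of the newly added configs; "40_trains" and "50_trains" are also registered now.
env = Flatland(map_name="30_trains")

observations, info = env.reset()  # assumed ResetReturn layout: (observations, info)
terminals = {agent: np.array(False) for agent in env.agents}

while not all(terminals.values()):
    # Sample uniformly among the legal actions exposed via the "legals" mask.
    actions = {
        agent: np.array(np.random.choice(np.flatnonzero(info["legals"][agent])))
        for agent in env.agents
    }
    observations, rewards, terminals, truncations, info = env.step(actions)

# On termination the wrapper now reports how many trains reached their target.
print("trains arrived:", info.get("arrived"))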
