from og_marl.wrapped_environments.base import BaseEnvironment, Observations, ResetReturn, StepReturn

FLATLAND_MAP_CONFIGS = {
-    "3trains": {
+    "3_trains": {
        "num_trains": 3,
        "num_cities": 2,
        "width": 25,
        "height": 25,
        "max_episode_len": 80,
    },
-    "5trains": {
+    "5_trains": {
        "num_trains": 5,
        "num_cities": 2,
        "width": 25,
        "height": 25,
        "max_episode_len": 100,
    },
+    "20_trains": {
+        "num_trains": 20,
+        "num_cities": 3,
+        "width": 30,
+        "height": 30,
+        "max_episode_len": 100,
+    },
+    "30_trains": {
+        "num_trains": 30,
+        "num_cities": 3,
+        "width": 35,
+        "height": 30,
+        "max_episode_len": 100,
+    },
+    "40_trains": {
+        "num_trains": 40,
+        "num_cities": 4,
+        "width": 35,
+        "height": 35,
+        "max_episode_len": 100,
+    },
+    "50_trains": {
+        "num_trains": 50,
+        "num_cities": 4,
+        "width": 35,
+        "height": 35,
+        "max_episode_len": 100,
+    },
}


class Flatland(BaseEnvironment):
    def __init__(self, map_name: str = "5_trains"):
        map_config = FLATLAND_MAP_CONFIGS[map_name]

-        self._num_actions = 5
+        self.num_actions = 5
        self.num_agents = map_config["num_trains"]
        self._num_cities = map_config["num_cities"]
        self._map_width = map_config["width"]
        self._map_height = map_config["height"]
        self._tree_depth = 2

-        self.possible_agents = [f"{i}" for i in range(self.num_agents)]
+        self.agents = [f"{i}" for i in range(self.num_agents)]

        self.rail_generator = sparse_rail_generator(max_num_cities=self._num_cities)
@@ -75,19 +103,22 @@ def __init__(self, map_name: str = "5_trains"):

        self._obs_dim = 11 * sum(4**i for i in range(self._tree_depth + 1)) + 7

-        self.action_spaces = {agent: Discrete(self._num_actions) for agent in self.possible_agents}
+        self.action_spaces = {agent: Discrete(self.num_actions) for agent in self.agents}
        self.observation_spaces = {
-            agent: Box(-np.inf, np.inf, (self._obs_dim,)) for agent in self.possible_agents
+            agent: Box(-np.inf, np.inf, (self._obs_dim,)) for agent in self.agents
        }

        self.info_spec = {
            "state": np.zeros((11 * self.num_agents,), "float32"),
            "legals": {
-                agent: np.zeros((self._num_actions,), "int64") for agent in self.possible_agents
+                agent: np.zeros((self.num_actions,), "int64") for agent in self.agents
            },
        }

-        self.max_episode_length = map_config["max_episode_len"]
+
+    def render(self) -> Any:
+        """Return frame for rendering"""
+        return self._environment.render()

    def reset(self) -> ResetReturn:
        self._done = False
@@ -116,7 +147,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:
        # Rewards
        rewards = {
            agent: np.array(all_rewards[int(agent)], dtype="float32")
-            for agent in self.possible_agents
+            for agent in self.agents
        }

        # Legal actions
@@ -130,21 +161,25 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:

        info = {"state": state, "legals": legal_actions}

-        terminals = {agent: np.array(self._done) for agent in self.possible_agents}
-        truncations = {agent: np.array(False) for agent in self.possible_agents}
+        if self._done:
+            num_arrived = sum(self._environment.agents[int(agent)].state == 6 for agent in self.agents)
+            info["arrived"] = num_arrived
+
+        terminals = {agent: np.array(self._done) for agent in self.agents}
+        truncations = {agent: np.array(False) for agent in self.agents}

        return next_observations, rewards, terminals, truncations, info

    def _get_legal_actions(self) -> Dict[str, np.ndarray]:
        legal_actions = {}
-        for agent in self.possible_agents:
+        for agent in self.agents:
            agent_id = int(agent)
            flatland_agent = self._environment.agents[agent_id]

            if not self._environment.action_required(
-                flatland_agent.state, flatland_agent.speed_counter.is_cell_entry
+                flatland_agent
            ):
-                legals = np.zeros(self._num_actions, "float32")
+                legals = np.zeros(self.num_actions, "float32")
                legals[0] = 1  # can only do nothing
            else:
                legals = np.ones(5, "float32")
@@ -155,7 +190,7 @@ def _get_legal_actions(self) -> Dict[str, np.ndarray]:

    def _make_state_representation(self) -> np.ndarray:
        state = []
-        for i, _ in enumerate(self.possible_agents):
+        for i, _ in enumerate(self.agents):
            agent = self._environment.agents[i]
            state.append(np.array(agent.target, "float32"))

@@ -179,7 +214,7 @@ def _convert_observations(
        info: Dict[str, Dict[int, np.ndarray]],
    ) -> Observations:
        new_observations = {}
-        for i, agent in enumerate(self.possible_agents):
+        for i, agent in enumerate(self.agents):
            agent_id = i
            norm_observation = normalize_observation(
                observations[agent_id],
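
For orientation, here is a minimal usage sketch (not part of the commit) showing how the renamed map keys and the new `info["arrived"]` count might be exercised. The module path `og_marl.wrapped_environments.flatland_wrapper`, the `(observations, info)` shape of `ResetReturn`, and the `.sample()` call on the per-agent `Discrete` spaces are assumptions not confirmed by this diff.

```python
import numpy as np

from og_marl.wrapped_environments.flatland_wrapper import Flatland  # module path assumed

env = Flatland(map_name="20_trains")  # one of the map keys added in this commit
observations, info = env.reset()  # assumes ResetReturn is (observations, info)

terminals = {agent: np.array(False) for agent in env.agents}
while not all(terminals.values()):
    # Sample one random action per train; a real policy would mask with info["legals"].
    actions = {agent: np.array(env.action_spaces[agent].sample()) for agent in env.agents}
    observations, rewards, terminals, truncations, info = env.step(actions)

# "arrived" counts trains whose final state equals 6 (DONE) at episode end.
print("Trains arrived:", info.get("arrived", 0))
```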