-
Notifications
You must be signed in to change notification settings - Fork 607
Expand file tree
/
Copy pathturnbasedmultiagentenv.py
More file actions
327 lines (283 loc) · 12.3 KB
/
turnbasedmultiagentenv.py
File metadata and controls
327 lines (283 loc) · 12.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# ------------------------------------------------------------------------------------------------
# Copyright (c) 2020 Microsoft Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ------------------------------------------------------------------------------------------------
import time
from threading import Thread
from lxml import etree
from ray.rllib.env.multi_agent_env import MultiAgentEnv
import malmoenv
from malmoenv.core import EnvException
# Seconds to sleep before each agent's step request in TurnBasedRllibMultiAgentEnv.step, giving
# the Minecraft instances time to sync up and agree whose turn it is to act.
STEP_DELAY_TIME = 0.15
def _validate_config(xml, agent_configs):
"""
Verify that the supplied agent config is compatible with the mission XML.
"""
assert len(agent_configs) >= 2
xml = etree.fromstring(xml)
xml_agent_count = len(xml.findall("{http://ProjectMalmo.microsoft.com}AgentSection"))
assert len(agent_configs) == xml_agent_count
def _parse_address(address):
"""
Take addresses of various forms and convert them to a tuple of the form (HOST, PORT).
"""
if isinstance(address, int):
# Only a port number provided
return ("127.0.0.1", address)
if isinstance(address, str):
parts = address.split(":")
if len(parts) == 1:
# Port number as a string
return ("127.0.0.1", int(parts[0]))
if len(parts) == 2:
# String in the form "HOST:PORT"
return (parts[0], int(parts[1]))
if len(address) == 2 and isinstance(address[0], str) and isinstance(address[1], int):
# An already parsed address
return address
raise EnvException(f"{address} is not a valid address")
def _await_results(results):
"""
Receives a dictionary of result tasks and repopulates it with the final results after the tasks
complete.
"""
for agent_id, task in results.items():
results[agent_id] = task.wait()
def _default_env_factory(agent_id, xml, role, host_address, host_port, command_address, command_port):
    """
    Create a MalmoEnv instance initialised with only the settings needed to join several game
    instances into one shared game session.

    agent_id - Id of the agent this connection is built for (unused by this default factory).
    xml - The mission XML.
    role - The agent's role number; role 0 is the host agent.
    host_address, host_port - Connection details for the game session host.
    command_address, command_port - Connection details for the game instance this agent controls.

    Returns the initialised environment.
    """
    init_options = dict(
        server=host_address,
        server2=command_address,
        port2=command_port,
        role=role,
        exp_uid="default_experiment_id",
    )
    malmo_env = malmoenv.make()
    malmo_env.init(xml, host_port, **init_options)
    return malmo_env
def _default_all_done_checker(env, obs, rewards, dones, infos):
"""
Returns True if any agent is reported as done.
"""
for done in dones.values():
if done:
return True
return False
# Wraps a MalmoEnv instance and provides async reset and sync step operations
# Reset operations need to be executed async as none of the connected environments will complete
# their reset operations until all environments have at least issued a reset request.
class _ConnectionContext:
def __init__(self, id, address, env):
"""
Wrapper around a connection to a game instance.
id - The agent id that is in control of the game instance.
address - (server, port) tuple for the command connection.
env - The MalmoEnv instance that is connected to the game instance.
"""
self.id = id
self.address = address
self.env = env
self.last_observation = None
# Async task status tracking
self._task_thread = None
self._task_result = None
def wait(self):
"""
Wait for the current async task to complete and return the result.
"""
assert self._task_thread is not None
self._task_thread.join()
self._task_thread = None
# We want to re-trow the exception if the task raised an error
if isinstance(self._task_result, Exception):
raise self._task_result
return self._task_result
def reset(self):
"""
Issue a reset request and return the async task immediately.
"""
assert self._task_thread is None
self._task_thread = Thread(target=self._reset_task, name=f"Agent '{self.id}' reset")
self._task_thread.start()
return self
def _reset_task(self):
try:
self._task_result = self.last_observation = self.env.reset()
except Exception as e:
self._task_result = e
def step(self, action):
"""
Issue a step request and return the async task immediately.
"""
self.last_observation, r, d, i = self.env.step(action)
return self.last_observation, r, d, i
def close(self):
"""
Shut down the Minecraft instance.
"""
self.env.close()
# Describes one agent that takes part in the multi-agent environment.
class AgentConfig:
    def __init__(self, id, address):
        """
        Per-agent settings for the environment.

        id - Identifier RLlib uses for this agent.
        address - Where this agent's game instance listens. Any form accepted by
                  _parse_address (port int, "PORT", "HOST:PORT", or a (HOST, PORT) pair);
                  stored normalised as (HOST, PORT).
        """
        parsed_address = _parse_address(address)
        self.id = id
        self.address = parsed_address
# RLlib compatible multi-agent environment.
# This wraps multiple instances of MalmoEnv environments that are connected to their own Minecraft
# instances.
# The first agent defined in the agent_configs is treated as the primary Minecraft instance that
# will act as the game server.
class TurnBasedRllibMultiAgentEnv(MultiAgentEnv):
    def __init__(self, xml, agent_configs, env_factory=None, all_done_checker=None):
        """
        An RLlib compatible multi-agent environment for turn based Malmo missions.

        Agents are stepped one at a time, in the iteration order of the actions dictionary, with
        a STEP_DELAY_TIME pause before each agent's step request.

        xml - The mission XML.
        agent_configs - A list of AgentConfigs describing the agents within the environment.
                        The first entry is the game session host. At least two are required,
                        and their number must match the AgentSections in the mission XML.
        env_factory - Function to allow custom construction of the MalmoEnv instances.
                      This can be used to override the default init parameters for the
                      environment. It is called with the keyword arguments agent_id, xml, role,
                      host_address, host_port, command_address and command_port.
        all_done_checker - Function to check if the "__all__" key should be set in the step done
                           dictionary; called as all_done_checker(env, obs, rewards, dones,
                           infos). The default check returns True if any agent reports that
                           they're done.
        """
        _validate_config(xml, agent_configs)
        self._all_done_checker = all_done_checker or _default_all_done_checker
        env_factory = env_factory or _default_env_factory

        # The first agent is treated as the game session host
        host_address = agent_configs[0].address
        self._id = host_address
        self._connections = {}
        # Seconds taken by the most recent reset()
        self._reset_request_time = 0
        # Number of step() calls since the last reset()
        self._step = 0

        # Roles follow the order of agent_configs; role 0 is the host.
        for role, agent_config in enumerate(agent_configs):
            env = env_factory(
                agent_id=agent_config.id,
                xml=xml,
                role=role,
                host_address=host_address[0],
                host_port=host_address[1],
                command_address=agent_config.address[0],
                command_port=agent_config.address[1],
            )
            self._connections[agent_config.id] = _ConnectionContext(
                agent_config.id,
                agent_config.address,
                env,
            )

    def get_observation_space(self, agent_id):
        """Return the observation space of the given agent's MalmoEnv instance."""
        return self._connections[agent_id].env.observation_space

    def get_action_space(self, agent_id):
        """Return the action space of the given agent's MalmoEnv instance."""
        return self._connections[agent_id].env.action_space

    def reset(self):
        """
        Reset every game instance and return a dict mapping agent id -> initial observation.
        """
        self._step = 0
        request_time = time.perf_counter()
        # All reset operations must be issued asynchronously as none of the Minecraft instances
        # will complete their reset requests until all agents have issued a reset request
        obs = {
            agent_id: connection.reset()
            for agent_id, connection in self._connections.items()
        }
        _await_results(obs)
        self._reset_request_time = time.perf_counter() - request_time
        return obs

    def step(self, actions):
        """
        Step each agent in turn and return RLlib's (obs, rewards, dones, infos) dictionaries.

        actions - dict mapping agent id -> action; agents are stepped in its iteration order.

        Once an agent reports done, the remaining agents are not stepped. Instead they get their
        last observation, a 0.0 reward, done=True and an empty info dict. dones["__all__"] is
        set by the configured all_done_checker.
        """
        self._step += 1
        # RLlib expects one dictionary per data type, keyed by agent id
        obs, rewards, dones, infos = {}, {}, {}, {}
        done = False
        for agent_id, action in actions.items():
            connection = self._connections[agent_id]
            if not done:
                # We need to wait a small amount of time between each agent's step request to give
                # the Minecraft instances time to sync up and agree whose turn to act it is
                time.sleep(STEP_DELAY_TIME)
                o, r, done, i = connection.step(action)
            else:
                # If any of the agents report themselves as "done", then we should stop taking turns
                # so generate a dummy step result based on the last observation so that training
                # receives valid looking data
                o = connection.last_observation
                r = 0.0
                i = {}
            # getattr: the observation may be None (or lack .shape), which must not turn the
            # assertion message itself into an AttributeError.
            assert connection.env.observation_space.contains(o), (
                f"Observation for agent {agent_id!r} is outside its observation space: "
                f"Shape={getattr(o, 'shape', None)}"
            )
            obs[agent_id] = o
            rewards[agent_id] = r
            dones[agent_id] = done
            infos[agent_id] = i

        # Pass the results to the done checker to set the required __all__ value
        dones["__all__"] = self._all_done_checker(self, obs, rewards, dones, infos)
        return obs, rewards, dones, infos

    def close(self):
        """
        Close every game instance connection, best effort: an error closing one connection is
        printed and does not stop the remaining connections from being closed.
        """
        for connection in self._connections.values():
            try:
                connection.close()
            except Exception as e:
                message = getattr(e, "message", e)
                print(f"Error closing environment: {message}")
# Malmo answers a step with an observation that predates the action just taken. This wrapper
# re-synchronises observations with actions: after the policy's action it issues an idle action
# purely to query the resulting state of the environment.
class SyncRllibMultiAgentEnv(MultiAgentEnv):
    def __init__(self, env, idle_action):
        """
        env - The multi-agent environment being wrapped.
        idle_action - Action sent to each acting agent on the follow-up query step.
        """
        self.env = env
        self.idle_action = idle_action

    def reset(self):
        """Reset the wrapped environment and return its result unchanged."""
        return self.env.reset()

    def step(self, actions):
        """
        Send `actions`; unless something is reported done, follow up with an idle step for the
        same agents and return that follow-up step's results instead.
        """
        # The first step's results are stale and are only passed back when Malmo reports
        # one of the instances as "done", in which case no follow-up step is issued.
        stale_obs, stale_rewards, stale_dones, stale_infos = self.env.step(actions)
        if any(stale_dones.values()):
            return stale_obs, stale_rewards, stale_dones, stale_infos

        # The follow-up step is really just a query for the environment state. With the turn
        # based environment, a delay is injected before the requests, which lets the environment
        # settle into the new state.
        idle_actions = dict.fromkeys(actions, self.idle_action)
        return self.env.step(idle_actions)

    def close(self):
        """Close the wrapped environment."""
        return self.env.close()