Skip to content

Commit 15fcf95

Browse files
author
Jonathan Harper
committed
Fix subprocess model saving on Windows
On Windows the interrupt for subprocesses works in a different way from OSX/Linux. The result is that child subprocesses and their pipes may close while the parent process is still running during a keyboard (ctrl+C) interrupt. To handle this, this change adds handling for EOFError and BrokenPipeError exceptions when interacting with subprocess environments. Additional management is also added to be sure when using parallel runs using the "num-runs" option that the threads for each run are joined and KeyboardInterrupts are handled. These changes made the "_win_handler" we used to specially manage interrupts on Windows unnecessary, so they have been removed.
1 parent d2ad6e8 commit 15fcf95

File tree

4 files changed

+29
-29
lines changed

4 files changed

+29
-29
lines changed

ml-agents-envs/mlagents/envs/subprocess_environment.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,24 @@ class UnityEnvWorker(NamedTuple):
2727
conn: Connection
2828

2929
def send(self, name: str, payload=None):
30-
cmd = EnvironmentCommand(name, payload)
31-
self.conn.send(cmd)
30+
try:
31+
cmd = EnvironmentCommand(name, payload)
32+
self.conn.send(cmd)
33+
except (BrokenPipeError, EOFError):
34+
raise KeyboardInterrupt
3235

3336
def recv(self) -> EnvironmentResponse:
34-
response: EnvironmentResponse = self.conn.recv()
35-
return response
37+
try:
38+
response: EnvironmentResponse = self.conn.recv()
39+
return response
40+
except (BrokenPipeError, EOFError):
41+
raise KeyboardInterrupt
3642

3743
def close(self):
38-
self.conn.send(EnvironmentCommand("close"))
44+
try:
45+
self.conn.send(EnvironmentCommand('close'))
46+
except (BrokenPipeError, EOFError):
47+
pass
3948
self.process.join()
4049

4150

@@ -87,10 +96,10 @@ def create_worker(
8796
env_factory: Callable[[int], BaseUnityEnvironment]
8897
) -> UnityEnvWorker:
8998
parent_conn, child_conn = Pipe()
99+
90100
# Need to use cloudpickle for the env factory function since function objects aren't picklable
91101
# on Windows as of Python 3.6.
92102
pickled_env_factory = cloudpickle.dumps(env_factory)
93-
94103
child_process = Process(target=worker, args=(child_conn, pickled_env_factory, worker_id))
95104
child_process.start()
96105
return UnityEnvWorker(child_process, worker_id, parent_conn)

ml-agents/mlagents/trainers/learn.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
5151
trainer_config_path = run_options['<trainer-config-path>']
5252
# Recognize and use docker volume if one is passed as an argument
5353
if not docker_target_name:
54-
model_path = './models/{run_id}'.format(run_id=run_id)
54+
model_path = './models/{run_id}-{sub_id}'.format(run_id=run_id, sub_id=sub_id)
5555
summaries_dir = './summaries'
5656
else:
5757
trainer_config_path = \
@@ -63,9 +63,10 @@ def run_training(sub_id: int, run_seed: int, run_options, process_queue):
6363
'/{docker_target_name}/{curriculum_folder}'.format(
6464
docker_target_name=docker_target_name,
6565
curriculum_folder=curriculum_folder)
66-
model_path = '/{docker_target_name}/models/{run_id}'.format(
66+
model_path = '/{docker_target_name}/models/{run_id}-{sub_id}'.format(
6767
docker_target_name=docker_target_name,
68-
run_id=run_id)
68+
run_id=run_id,
69+
sub_id=sub_id)
6970
summaries_dir = '/{docker_target_name}/summaries'.format(
7071
docker_target_name=docker_target_name)
7172

@@ -274,6 +275,14 @@ def main():
274275
while process_queue.get() is not True:
275276
continue
276277

278+
# Wait for jobs to complete. Otherwise we'll have an extra
279+
# unhandled KeyboardInterrupt if we end early.
280+
try:
281+
for job in jobs:
282+
job.join()
283+
except KeyboardInterrupt:
284+
pass
285+
277286
# For python debugger to directly run this script
278287
if __name__ == "__main__":
279288
main()

ml-agents/mlagents/trainers/tests/test_learn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_run_training(load_config, create_environment_factory, subproc_env_mock)
4141
with patch.object(TrainerController, "start_learning", MagicMock()):
4242
learn.run_training(0, 0, basic_options(), MagicMock())
4343
mock_init.assert_called_once_with(
44-
'./models/ppo',
44+
'./models/ppo-0',
4545
'./summaries',
4646
'ppo-0',
4747
50000,
@@ -74,5 +74,5 @@ def test_docker_target_path(load_config, create_environment_factory, subproc_env
7474
with patch.object(TrainerController, "start_learning", MagicMock()):
7575
learn.run_training(0, 0, options_with_docker_target, MagicMock())
7676
mock_init.assert_called_once()
77-
assert(mock_init.call_args[0][0] == '/dockertarget/models/ppo')
77+
assert(mock_init.call_args[0][0] == '/dockertarget/models/ppo-0')
7878
assert(mock_init.call_args[0][1] == '/dockertarget/summaries')

ml-agents/mlagents/trainers/trainer_controller.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
import logging
77
import shutil
88
import sys
9-
if sys.platform.startswith('win'):
10-
import win32api
11-
import win32con
129
from typing import *
1310

1411
import numpy as np
@@ -104,18 +101,6 @@ def _save_model_when_interrupted(self, steps=0):
104101
'while the graph is generated.')
105102
self._save_model(steps)
106103

107-
def _win_handler(self, event):
108-
"""
109-
This function gets triggered after ctrl-c or ctrl-break is pressed
110-
under Windows platform.
111-
"""
112-
if event in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
113-
self._save_model_when_interrupted(self.global_step)
114-
self._export_graph()
115-
sys.exit()
116-
return True
117-
return False
118-
119104
def _write_training_metrics(self):
120105
"""
121106
Write all CSV metrics
@@ -223,9 +208,6 @@ def start_learning(self, env: BaseUnityEnvironment, trainer_config):
223208
for brain_name, trainer in self.trainers.items():
224209
trainer.write_tensorboard_text('Hyperparameters',
225210
trainer.parameters)
226-
if sys.platform.startswith('win'):
227-
# Add the _win_handler function to the windows console's handler function list
228-
win32api.SetConsoleCtrlHandler(self._win_handler, True)
229211
try:
230212
curr_info = self._reset_env(env)
231213
while any([t.get_step <= t.get_max_steps \

0 commit comments

Comments
 (0)