Skip to content

Commit 5510062

Browse files
authored
fixed the windows ctrl-c bug (#1558)
* Documentation tweaks and updates (#1479) * Add blurb about using the --load flag in the intro guide, and typo fix. * Add section in tutorial to create multiple area learning environment. * Add mention of Done() method in agent design * fixed the windows ctrl-c bug * fixed typo * removed some uncessary printing * nothing * make the import of the win api conditional * removved the duplicate code * added the ability to use python debugger on ml-agents * added newline at the end, changed the import to be complete path * changed the info.log into policy.export_model, changed the sys.platform to use startswith * fixed a bug * remove the printing of the path * tweaked the info message to notify the user about the expected error message * removed some logging according to comments * removed the sys import * Revert "Documentation tweaks and updates (#1479)" This reverts commit 84ef07a. * resolved the model path comment
1 parent 4a52af9 commit 5510062

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

ml-agents/mlagents/trainers/learn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import numpy as np
77
from docopt import docopt
88

9-
from .trainer_controller import TrainerController
10-
from .exception import TrainerError
9+
from mlagents.trainers.trainer_controller import TrainerController
10+
from mlagents.trainers.exception import TrainerError
1111

1212

1313
def run_training(sub_id, run_seed, run_options, process_queue):
@@ -117,3 +117,7 @@ def main():
117117
# Wait for signal that environment has successfully launched
118118
while process_queue.get() is not True:
119119
continue
120+
121+
# For python debugger to directly run this script
122+
if __name__ == "__main__":
123+
main()

ml-agents/mlagents/trainers/policy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def export_model(self):
179179
clear_devices=True, initializer_nodes='', input_saver='',
180180
restore_op_name='save/restore_all',
181181
filename_tensor_name='save/Const:0')
182+
logger.info('Exported ' + self.model_path + '.bytes file')
182183

183184
def _process_graph(self):
184185
"""

ml-agents/mlagents/trainers/trainer_controller.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import glob
77
import logging
88
import shutil
9+
import sys
10+
if sys.platform.startswith('win'):
11+
import win32api
12+
import win32con
913

1014
import yaml
1115
import re
@@ -103,6 +107,7 @@ def __init__(self, env_path, run_id, save_freq, curriculum_folder,
103107
self.keep_checkpoints = keep_checkpoints
104108
self.trainers = {}
105109
self.seed = seed
110+
self.global_step = 0
106111
np.random.seed(self.seed)
107112
tf.set_random_seed(self.seed)
108113
self.env = UnityEnvironment(file_name=env_path,
@@ -181,6 +186,23 @@ def _save_model(self,steps=0):
181186
self.trainers[brain_name].save_model()
182187
self.logger.info('Saved Model')
183188

189+
def _save_model_when_interrupted(self, steps=0):
190+
self.logger.info('Learning was interrupted. Please wait '
191+
'while the graph is generated.')
192+
self._save_model(steps)
193+
194+
def _win_handler(self, event):
195+
"""
196+
This function gets triggered after ctrl-c or ctrl-break is pressed
197+
under Windows platform.
198+
"""
199+
if event in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
200+
self._save_model_when_interrupted(self.global_step)
201+
self._export_graph()
202+
sys.exit()
203+
return True
204+
return False
205+
184206
def _export_graph(self):
185207
"""
186208
Exports latest saved models to .bytes format for Unity embedding.
@@ -288,12 +310,14 @@ def start_learning(self):
288310
self._initialize_trainers(trainer_config)
289311
for _, t in self.trainers.items():
290312
self.logger.info(t)
291-
global_step = 0 # This is only for saving the model
292313
curr_info = self._reset_env()
293314
if self.train_model:
294315
for brain_name, trainer in self.trainers.items():
295316
trainer.write_tensorboard_text('Hyperparameters',
296317
trainer.parameters)
318+
if sys.platform.startswith('win'):
319+
# Add the _win_handler function to the windows console's handler function list
320+
win32api.SetConsoleCtrlHandler(self._win_handler, True)
297321
try:
298322
while any([t.get_step <= t.get_max_steps \
299323
for k, t in self.trainers.items()]) \
@@ -353,31 +377,27 @@ def start_learning(self):
353377
# Write training statistics to Tensorboard.
354378
if self.meta_curriculum is not None:
355379
trainer.write_summary(
356-
global_step,
380+
self.global_step,
357381
lesson_num=self.meta_curriculum
358382
.brains_to_curriculums[brain_name]
359383
.lesson_num)
360384
else:
361-
trainer.write_summary(global_step)
385+
trainer.write_summary(self.global_step)
362386
if self.train_model \
363387
and trainer.get_step <= trainer.get_max_steps:
364388
trainer.increment_step_and_update_last_reward()
365-
global_step += 1
366-
if global_step % self.save_freq == 0 and global_step != 0 \
389+
self.global_step += 1
390+
if self.global_step % self.save_freq == 0 and self.global_step != 0 \
367391
and self.train_model:
368392
# Save Tensorflow model
369-
self._save_model(steps=global_step)
393+
self._save_model(steps=self.global_step)
370394
curr_info = new_info
371395
# Final save Tensorflow model
372-
if global_step != 0 and self.train_model:
373-
self._save_model(steps=global_step)
396+
if self.global_step != 0 and self.train_model:
397+
self._save_model(steps=self.global_step)
374398
except KeyboardInterrupt:
375-
print('--------------------------Now saving model--------------'
376-
'-----------')
377399
if self.train_model:
378-
self.logger.info('Learning was interrupted. Please wait '
379-
'while the graph is generated.')
380-
self._save_model(steps=global_step)
400+
self._save_model_when_interrupted(steps=self.global_step)
381401
pass
382402
self.env.close()
383403
if self.train_model:

0 commit comments

Comments
 (0)