|
6 | 6 | import glob
|
7 | 7 | import logging
|
8 | 8 | import shutil
|
| 9 | +import sys |
| 10 | +if sys.platform.startswith('win'): |
| 11 | + import win32api |
| 12 | + import win32con |
9 | 13 |
|
10 | 14 | import yaml
|
11 | 15 | import re
|
@@ -103,6 +107,7 @@ def __init__(self, env_path, run_id, save_freq, curriculum_folder,
|
103 | 107 | self.keep_checkpoints = keep_checkpoints
|
104 | 108 | self.trainers = {}
|
105 | 109 | self.seed = seed
|
| 110 | + self.global_step = 0 |
106 | 111 | np.random.seed(self.seed)
|
107 | 112 | tf.set_random_seed(self.seed)
|
108 | 113 | self.env = UnityEnvironment(file_name=env_path,
|
@@ -181,6 +186,23 @@ def _save_model(self,steps=0):
|
181 | 186 | self.trainers[brain_name].save_model()
|
182 | 187 | self.logger.info('Saved Model')
|
183 | 188 |
|
def _save_model_when_interrupted(self, steps=0):
    """
    Logs that learning was interrupted and then saves the model.
    :param steps: Step count forwarded unchanged to _save_model.
    """
    interrupted_msg = ('Learning was interrupted. Please wait '
                       'while the graph is generated.')
    self.logger.info(interrupted_msg)
    self._save_model(steps)
| 193 | + |
def _win_handler(self, event):
    """
    This function gets triggered after ctrl-c or ctrl-break is pressed
    under Windows platform.
    For either of those events it saves the model at the current global
    step, exports the graph and then calls sys.exit().
    :param event: Control event code passed to the handler.
    :return: False when the event is neither ctrl-c nor ctrl-break. For
    those two events the function does not return normally, because
    sys.exit() raises SystemExit.
    """
    if event not in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
        # Not an interrupt event this handler deals with.
        return False
    self._save_model_when_interrupted(self.global_step)
    self._export_graph()
    # sys.exit() raises SystemExit, so no statement after it can run; the
    # former trailing `return True` was unreachable and has been removed.
    sys.exit()
| 205 | + |
184 | 206 | def _export_graph(self):
|
185 | 207 | """
|
186 | 208 | Exports latest saved models to .bytes format for Unity embedding.
|
@@ -288,12 +310,14 @@ def start_learning(self):
|
288 | 310 | self._initialize_trainers(trainer_config)
|
289 | 311 | for _, t in self.trainers.items():
|
290 | 312 | self.logger.info(t)
|
291 |
| - global_step = 0 # This is only for saving the model |
292 | 313 | curr_info = self._reset_env()
|
293 | 314 | if self.train_model:
|
294 | 315 | for brain_name, trainer in self.trainers.items():
|
295 | 316 | trainer.write_tensorboard_text('Hyperparameters',
|
296 | 317 | trainer.parameters)
|
| 318 | + if sys.platform.startswith('win'): |
| 319 | + # Add the _win_handler function to the windows console's handler function list |
| 320 | + win32api.SetConsoleCtrlHandler(self._win_handler, True) |
297 | 321 | try:
|
298 | 322 | while any([t.get_step <= t.get_max_steps \
|
299 | 323 | for k, t in self.trainers.items()]) \
|
@@ -353,31 +377,27 @@ def start_learning(self):
|
353 | 377 | # Write training statistics to Tensorboard.
|
354 | 378 | if self.meta_curriculum is not None:
|
355 | 379 | trainer.write_summary(
|
356 |
| - global_step, |
| 380 | + self.global_step, |
357 | 381 | lesson_num=self.meta_curriculum
|
358 | 382 | .brains_to_curriculums[brain_name]
|
359 | 383 | .lesson_num)
|
360 | 384 | else:
|
361 |
| - trainer.write_summary(global_step) |
| 385 | + trainer.write_summary(self.global_step) |
362 | 386 | if self.train_model \
|
363 | 387 | and trainer.get_step <= trainer.get_max_steps:
|
364 | 388 | trainer.increment_step_and_update_last_reward()
|
365 |
| - global_step += 1 |
366 |
| - if global_step % self.save_freq == 0 and global_step != 0 \ |
| 389 | + self.global_step += 1 |
| 390 | + if self.global_step % self.save_freq == 0 and self.global_step != 0 \ |
367 | 391 | and self.train_model:
|
368 | 392 | # Save Tensorflow model
|
369 |
| - self._save_model(steps=global_step) |
| 393 | + self._save_model(steps=self.global_step) |
370 | 394 | curr_info = new_info
|
371 | 395 | # Final save Tensorflow model
|
372 |
| - if global_step != 0 and self.train_model: |
373 |
| - self._save_model(steps=global_step) |
| 396 | + if self.global_step != 0 and self.train_model: |
| 397 | + self._save_model(steps=self.global_step) |
374 | 398 | except KeyboardInterrupt:
|
375 |
| - print('--------------------------Now saving model--------------' |
376 |
| - '-----------') |
377 | 399 | if self.train_model:
|
378 |
| - self.logger.info('Learning was interrupted. Please wait ' |
379 |
| - 'while the graph is generated.') |
380 |
| - self._save_model(steps=global_step) |
| 400 | + self._save_model_when_interrupted(steps=self.global_step) |
381 | 401 | pass
|
382 | 402 | self.env.close()
|
383 | 403 | if self.train_model:
|
|
0 commit comments