|
| 1 | +import os |
| 2 | +from mlagents.trainers.exception import UnityTrainerException |
| 3 | + |
| 4 | + |
| 5 | +def validate_existing_directories( |
| 6 | + output_path: str, resume: bool, force: bool, init_path: str = None |
| 7 | +) -> None: |
| 8 | + """ |
| 9 | + Validates that if the run_id model exists, we do not overwrite it unless --force is specified. |
| 10 | + Throws an exception if resume isn't specified and run_id exists. Throws an exception |
| 11 | + if --resume is specified and run-id was not found. |
| 12 | + :param model_path: The model path specified. |
| 13 | + :param summary_path: The summary path to be used. |
| 14 | + :param resume: Whether or not the --resume flag was passed. |
| 15 | + :param force: Whether or not the --force flag was passed. |
| 16 | + """ |
| 17 | + |
| 18 | + output_path_exists = os.path.isdir(output_path) |
| 19 | + |
| 20 | + if output_path_exists: |
| 21 | + if not resume and not force: |
| 22 | + raise UnityTrainerException( |
| 23 | + "Previous data from this run ID was found. " |
| 24 | + "Either specify a new run ID, use --resume to resume this run, " |
| 25 | + "or use the --force parameter to overwrite existing data." |
| 26 | + ) |
| 27 | + else: |
| 28 | + if resume: |
| 29 | + raise UnityTrainerException( |
| 30 | + "Previous data from this run ID was not found. " |
| 31 | + "Train a new run by removing the --resume flag." |
| 32 | + ) |
| 33 | + |
| 34 | + # Verify init path if specified. |
| 35 | + if init_path is not None: |
| 36 | + if not os.path.isdir(init_path): |
| 37 | + raise UnityTrainerException( |
| 38 | + "Could not initialize from {}. " |
| 39 | + "Make sure models have already been saved with that run ID.".format( |
| 40 | + init_path |
| 41 | + ) |
| 42 | + ) |
0 commit comments