Skip to content

Commit efc9d84

Browse files
author
Chris Elion
authored
Develop yaml json loading errors (#2601)
* WIP cleanup loading * better exceptions for parser errors - refer to online lint tools * feedback - rename variable
1 parent 2e0bab8 commit efc9d84

File tree

7 files changed

+140
-35
lines changed

7 files changed

+140
-35
lines changed

ml-agents/mlagents/trainers/curriculum.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import math
44

5-
from .exception import CurriculumError
5+
from .exception import CurriculumConfigError, CurriculumLoadingError
66

77
import logging
88

@@ -23,14 +23,8 @@ def __init__(self, location, default_reset_parameters):
2323
# The name of the brain should be the basename of the file without the
2424
# extension.
2525
self._brain_name = os.path.basename(location).split(".")[0]
26+
self.data = Curriculum.load_curriculum_file(location)
2627

27-
try:
28-
with open(location) as data_file:
29-
self.data = json.load(data_file)
30-
except IOError:
31-
raise CurriculumError("The file {0} could not be found.".format(location))
32-
except UnicodeDecodeError:
33-
raise CurriculumError("There was an error decoding {}".format(location))
3428
self.smoothing_value = 0
3529
for key in [
3630
"parameters",
@@ -40,7 +34,7 @@ def __init__(self, location, default_reset_parameters):
4034
"signal_smoothing",
4135
]:
4236
if key not in self.data:
43-
raise CurriculumError(
37+
raise CurriculumConfigError(
4438
"{0} does not contain a " "{1} field.".format(location, key)
4539
)
4640
self.smoothing_value = 0
@@ -51,12 +45,12 @@ def __init__(self, location, default_reset_parameters):
5145
parameters = self.data["parameters"]
5246
for key in parameters:
5347
if key not in default_reset_parameters:
54-
raise CurriculumError(
48+
raise CurriculumConfigError(
5549
"The parameter {0} in Curriculum {1} is not present in "
5650
"the Environment".format(key, location)
5751
)
5852
if len(parameters[key]) != self.max_lesson_num + 1:
59-
raise CurriculumError(
53+
raise CurriculumConfigError(
6054
"The parameter {0} in Curriculum {1} must have {2} values "
6155
"but {3} were found".format(
6256
key, location, self.max_lesson_num + 1, len(parameters[key])
@@ -117,3 +111,27 @@ def get_config(self, lesson=None):
117111
for key in parameters:
118112
config[key] = parameters[key][lesson]
119113
return config
114+
115+
@staticmethod
116+
def load_curriculum_file(location):
117+
try:
118+
with open(location) as data_file:
119+
return Curriculum._load_curriculum(data_file)
120+
except IOError:
121+
raise CurriculumLoadingError(
122+
"The file {0} could not be found.".format(location)
123+
)
124+
except UnicodeDecodeError:
125+
raise CurriculumLoadingError(
126+
"There was an error decoding {}".format(location)
127+
)
128+
129+
@staticmethod
130+
def _load_curriculum(fp):
131+
try:
132+
return json.load(fp)
133+
except json.decoder.JSONDecodeError as e:
134+
raise CurriculumLoadingError(
135+
"Error parsing JSON file. Please check for formatting errors. "
136+
"A tool such as https://jsonlint.com/ can be helpful with this."
137+
) from e

ml-agents/mlagents/trainers/exception.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@ class CurriculumError(TrainerError):
1919
pass
2020

2121

22+
class CurriculumLoadingError(CurriculumError):
23+
"""
24+
Any error related to loading the Curriculum config file.
25+
"""
26+
27+
pass
28+
29+
30+
class CurriculumConfigError(CurriculumError):
31+
"""
32+
Any error related to processing the Curriculum config file.
33+
"""
34+
35+
pass
36+
37+
2238
class MetaCurriculumError(TrainerError):
2339
"""
2440
Any error related to the configuration of a metacurriculum.

ml-agents/mlagents/trainers/learn.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
import glob
99
import shutil
1010
import numpy as np
11-
import yaml
12-
from typing import Any, Callable, Dict, Optional, List, NamedTuple
11+
12+
from typing import Any, Callable, Optional, List, NamedTuple
1313

1414

1515
from mlagents.trainers.trainer_controller import TrainerController
1616
from mlagents.trainers.exception import TrainerError
1717
from mlagents.trainers.meta_curriculum import MetaCurriculumError, MetaCurriculum
18-
from mlagents.trainers.trainer_util import initialize_trainers
18+
from mlagents.trainers.trainer_util import initialize_trainers, load_config
1919
from mlagents.envs.environment import UnityEnvironment
2020
from mlagents.envs.sampler_class import SamplerManager
21-
from mlagents.envs.exception import UnityEnvironmentException, SamplerException
21+
from mlagents.envs.exception import SamplerException
2222
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
2323
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager
2424

@@ -323,22 +323,6 @@ def prepare_for_docker_run(docker_target_name, env_path):
323323
return env_path
324324

325325

326-
def load_config(trainer_config_path: str) -> Dict[str, Any]:
327-
try:
328-
with open(trainer_config_path) as data_file:
329-
trainer_config = yaml.safe_load(data_file)
330-
return trainer_config
331-
except IOError:
332-
raise UnityEnvironmentException(
333-
"Parameter file could not be found " "at {}.".format(trainer_config_path)
334-
)
335-
except UnicodeDecodeError:
336-
raise UnityEnvironmentException(
337-
"There was an error decoding "
338-
"Trainer Config from this path : {}".format(trainer_config_path)
339-
)
340-
341-
342326
def create_environment_factory(
343327
env_path: str,
344328
docker_target_name: Optional[str],

ml-agents/mlagents/trainers/tests/test_curriculum.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import io
2+
import json
13
import pytest
24
from unittest.mock import patch, mock_open
35

4-
from mlagents.trainers.exception import CurriculumError
6+
from mlagents.trainers.exception import CurriculumConfigError, CurriculumLoadingError
57
from mlagents.trainers.curriculum import Curriculum
68

79

@@ -60,7 +62,7 @@ def test_init_curriculum_happy_path(mock_file, location, default_reset_parameter
6062
def test_init_curriculum_bad_curriculum_raises_error(
6163
mock_file, location, default_reset_parameters
6264
):
63-
with pytest.raises(CurriculumError):
65+
with pytest.raises(CurriculumConfigError):
6466
Curriculum(location, default_reset_parameters)
6567

6668

@@ -93,3 +95,30 @@ def test_get_config(mock_file):
9395
curriculum.lesson_num = 2
9496
assert curriculum.get_config() == {"param1": 0.3, "param2": 20, "param3": 0.7}
9597
assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2}
98+
99+
100+
# Test json loading and error handling. These examples don't need to valid config files.
101+
102+
103+
def test_curriculum_load_good():
104+
expected = {"x": 1}
105+
value = json.dumps(expected)
106+
fp = io.StringIO(value)
107+
assert expected == Curriculum._load_curriculum(fp)
108+
109+
110+
def test_curriculum_load_missing_file():
111+
with pytest.raises(CurriculumLoadingError):
112+
Curriculum.load_curriculum_file("notAValidFile.json")
113+
114+
115+
def test_curriculum_load_invalid_json():
116+
# This isn't valid json because of the trailing comma
117+
contents = """
118+
{
119+
"x": [1, 2, 3,]
120+
}
121+
"""
122+
fp = io.StringIO(contents)
123+
with pytest.raises(CurriculumLoadingError):
124+
Curriculum._load_curriculum(fp)

ml-agents/mlagents/trainers/tests/test_sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@pytest.fixture
1717
def dummy_config():
18-
return yaml.load(
18+
return yaml.safe_load(
1919
"""
2020
trainer: sac
2121
batch_size: 32

ml-agents/mlagents/trainers/tests/test_trainer_util.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pytest
22
import yaml
33
import os
4+
import io
45
from unittest.mock import patch
56

67
import mlagents.trainers.trainer_util as trainer_util
8+
from mlagents.trainers.trainer_util import load_config, _load_config
79
from mlagents.trainers.trainer_metrics import TrainerMetrics
810
from mlagents.trainers.ppo.trainer import PPOTrainer
911
from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer
@@ -313,3 +315,30 @@ def test_initialize_invalid_trainer_raises_exception(BrainParametersMock):
313315
load_model=load_model,
314316
seed=seed,
315317
)
318+
319+
320+
def test_load_config_missing_file():
321+
with pytest.raises(UnityEnvironmentException):
322+
load_config("thisFileDefinitelyDoesNotExist.yaml")
323+
324+
325+
def test_load_config_valid_yaml():
326+
file_contents = """
327+
this:
328+
- is fine
329+
"""
330+
fp = io.StringIO(file_contents)
331+
res = _load_config(fp)
332+
assert res == {"this": ["is fine"]}
333+
334+
335+
def test_load_config_invalid_yaml():
336+
file_contents = """
337+
you:
338+
- will
339+
- not
340+
- parse
341+
"""
342+
with pytest.raises(UnityEnvironmentException):
343+
fp = io.StringIO(file_contents)
344+
_load_config(fp)

ml-agents/mlagents/trainers/trainer_util.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Dict
1+
import yaml
2+
from typing import Any, Dict, TextIO
23

34
from mlagents.trainers.meta_curriculum import MetaCurriculum
45
from mlagents.envs.exception import UnityEnvironmentException
@@ -108,3 +109,31 @@ def initialize_trainers(
108109
"brain {}".format(brain_name)
109110
)
110111
return trainers
112+
113+
114+
def load_config(config_path: str) -> Dict[str, Any]:
115+
try:
116+
with open(config_path) as data_file:
117+
return _load_config(data_file)
118+
except IOError:
119+
raise UnityEnvironmentException(
120+
f"Config file could not be found at {config_path}."
121+
)
122+
except UnicodeDecodeError:
123+
raise UnityEnvironmentException(
124+
f"There was an error decoding Config file from {config_path}. "
125+
f"Make sure your file is save using UTF-8"
126+
)
127+
128+
129+
def _load_config(fp: TextIO) -> Dict[str, Any]:
130+
"""
131+
Load the yaml config from the file-like object.
132+
"""
133+
try:
134+
return yaml.safe_load(fp)
135+
except yaml.parser.ParserError as e:
136+
raise UnityEnvironmentException(
137+
"Error parsing yaml file. Please check for formatting errors. "
138+
"A tool such as http://www.yamllint.com/ can be helpful with this."
139+
) from e

0 commit comments

Comments
 (0)