Skip to content

Commit b7d2b80

Browse files
[life improvement] Moving Python files around (#4531)
* Moved components to the tf folder and moved the TrainerFactory to the `trainer` folder * Addressing comments * Editing the migrating doc * fixing test
1 parent 20539c9 commit b7d2b80

31 files changed

+261
-230
lines changed

docs/Migrating.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,21 @@ double-check that the versions are in the same. The versions can be found in
1414

1515
# Migrating
1616

17-
## Migrating from Release 3 to latest
17+
## Migrating from Release 7 to latest
18+
19+
### Important changes
20+
- Some trainer files were moved. If you were using the `TrainerFactory` class, it was moved to
21+
the `trainers/trainer` folder.
22+
- The `components` folder containing `bc` and `reward_signals` code was moved to the `trainers/tf`
23+
folder.
24+
25+
### Steps to Migrate
26+
- Replace `from mlagents.trainers.trainer_util import TrainerFactory` with `from mlagents.trainers.trainer import TrainerFactory`.
27+
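- Replace `from mlagents.trainers.trainer_util import handle_existing_directories` with `from mlagents.trainers.directory_utils import validate_existing_directories`. Note that the function was renamed as well as moved, so call sites must be updated to use `validate_existing_directories`.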
28+
- Replace `mlagents.trainers.components` with `mlagents.trainers.tf.components` in your import statements.
29+
30+
31+
## Migrating from Release 3 to Release 7
1832

1933
### Important changes
2034
- The Parameter Randomization feature has been merged with the Curriculum feature. It is now possible to specify a sampler
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import os
2+
from mlagents.trainers.exception import UnityTrainerException
3+
4+
5+
def validate_existing_directories(
6+
output_path: str, resume: bool, force: bool, init_path: str = None
7+
) -> None:
8+
"""
9+
Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
10+
Throws an exception if resume isn't specified and run_id exists. Throws an exception
11+
if --resume is specified and run-id was not found.
12+
:param model_path: The model path specified.
13+
:param summary_path: The summary path to be used.
14+
:param resume: Whether or not the --resume flag was passed.
15+
:param force: Whether or not the --force flag was passed.
16+
"""
17+
18+
output_path_exists = os.path.isdir(output_path)
19+
20+
if output_path_exists:
21+
if not resume and not force:
22+
raise UnityTrainerException(
23+
"Previous data from this run ID was found. "
24+
"Either specify a new run ID, use --resume to resume this run, "
25+
"or use the --force parameter to overwrite existing data."
26+
)
27+
else:
28+
if resume:
29+
raise UnityTrainerException(
30+
"Previous data from this run ID was not found. "
31+
"Train a new run by removing the --resume flag."
32+
)
33+
34+
# Verify init path if specified.
35+
if init_path is not None:
36+
if not os.path.isdir(init_path):
37+
raise UnityTrainerException(
38+
"Could not initialize from {}. "
39+
"Make sure models have already been saved with that run ID.".format(
40+
init_path
41+
)
42+
)

ml-agents/mlagents/trainers/learn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from mlagents import tf_utils
1313
from mlagents.trainers.trainer_controller import TrainerController
1414
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
15-
from mlagents.trainers.trainer_util import TrainerFactory, handle_existing_directories
15+
from mlagents.trainers.trainer import TrainerFactory
16+
from mlagents.trainers.directory_utils import validate_existing_directories
1617
from mlagents.trainers.stats import (
1718
TensorboardWriter,
1819
StatsReporter,
@@ -75,7 +76,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
7576
run_logs_dir = os.path.join(write_path, "run_logs")
7677
port: Optional[int] = env_settings.base_port
7778
# Check if directory exists
78-
handle_existing_directories(
79+
validate_existing_directories(
7980
write_path,
8081
checkpoint_settings.resume,
8182
checkpoint_settings.force,

ml-agents/mlagents/trainers/optimizer/tf_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from mlagents.trainers.policy.tf_policy import TFPolicy
77
from mlagents.trainers.optimizer import Optimizer
88
from mlagents.trainers.trajectory import SplitObservations
9-
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
9+
from mlagents.trainers.tf.components.reward_signals.reward_signal_factory import (
1010
create_reward_signal,
1111
)
1212
from mlagents.trainers.settings import TrainerSettings, RewardSignalType
13-
from mlagents.trainers.components.bc.module import BCModule
13+
from mlagents.trainers.tf.components.bc.module import BCModule
1414

1515

1616
class TFOptimizer(Optimizer): # pylint: disable=W0223

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from mlagents.trainers.trajectory import Trajectory
1717
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
1818
from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType
19-
from mlagents.trainers.components.reward_signals import RewardSignal
19+
from mlagents.trainers.tf.components.reward_signals import RewardSignal
2020
from mlagents import torch_utils
2121

2222
if torch_utils.is_available():

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mlagents.trainers.trajectory import Trajectory, SplitObservations
2020
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
2121
from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
22-
from mlagents.trainers.components.reward_signals import RewardSignal
22+
from mlagents.trainers.tf.components.reward_signals import RewardSignal
2323
from mlagents import torch_utils
2424

2525
if torch_utils.is_available():

ml-agents/mlagents/trainers/tests/check_env_trains.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from typing import Dict
55
from mlagents.trainers.trainer_controller import TrainerController
6-
from mlagents.trainers.trainer_util import TrainerFactory
6+
from mlagents.trainers.trainer import TrainerFactory
77
from mlagents.trainers.simple_env_manager import SimpleEnvManager
88
from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary
99
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager

ml-agents/mlagents/trainers/tests/tensorflow/test_bcmodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from mlagents.trainers.policy.tf_policy import TFPolicy
7-
from mlagents.trainers.components.bc.module import BCModule
7+
from mlagents.trainers.tf.components.bc.module import BCModule
88
from mlagents.trainers.settings import (
99
TrainerSettings,
1010
BehavioralCloningSettings,

ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
RecordEnvironment,
1212
)
1313
from mlagents.trainers.trainer_controller import TrainerController
14-
from mlagents.trainers.trainer_util import TrainerFactory
14+
from mlagents.trainers.trainer import TrainerFactory
1515
from mlagents.trainers.simple_env_manager import SimpleEnvManager
1616
from mlagents.trainers.demo_loader import write_demo
1717
from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary

ml-agents/mlagents/trainers/tests/test_learn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def basic_options(extra_args=None):
4747

4848
@patch("mlagents.trainers.learn.write_timing_tree")
4949
@patch("mlagents.trainers.learn.write_run_options")
50-
@patch("mlagents.trainers.learn.handle_existing_directories")
50+
@patch("mlagents.trainers.learn.validate_existing_directories")
5151
@patch("mlagents.trainers.learn.TrainerFactory")
5252
@patch("mlagents.trainers.learn.SubprocessEnvManager")
5353
@patch("mlagents.trainers.learn.create_environment_factory")

0 commit comments

Comments
 (0)