
Commit 1685c43

Chris Elion authored and Ervin Teng committed
small mypy cleanup (#2637)
* small mypy cleanup
* sac cleanup
* types for ppo policy init
1 parent 7f1cabe commit 1685c43

5 files changed: +23, -26 lines


ml-agents/mlagents/trainers/ppo/policy.py

Lines changed: 13 additions & 7 deletions
@@ -1,10 +1,10 @@
 import logging
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import tensorflow as tf

 from mlagents.envs.timers import timed
-from mlagents.envs.brain import BrainInfo
+from mlagents.envs.brain import BrainInfo, BrainParameters
 from mlagents.trainers.models import EncoderType, LearningRateSchedule
 from mlagents.trainers.ppo.models import PPOModel
 from mlagents.trainers.tf_policy import TFPolicy
@@ -17,7 +17,14 @@


 class PPOPolicy(TFPolicy):
-    def __init__(self, seed, brain, trainer_params, is_training, load):
+    def __init__(
+        self,
+        seed: int,
+        brain: BrainParameters,
+        trainer_params: Dict[str, Any],
+        is_training: bool,
+        load: bool,
+    ):
         """
         Policy for Proximal Policy Optimization Networks.
         :param seed: Random seed.
@@ -29,8 +36,8 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
         super().__init__(seed, brain, trainer_params)

         reward_signal_configs = trainer_params["reward_signals"]
-        self.inference_dict = {}
-        self.update_dict = {}
+        self.inference_dict: Dict[str, tf.Tensor] = {}
+        self.update_dict: Dict[str, tf.Tensor] = {}
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",
@@ -42,6 +49,7 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
         self.create_reward_signals(reward_signal_configs)

         with self.graph.as_default():
+            self.bc_module: Optional[BCModule] = None
             # Create pretrainer if needed
             if "pretraining" in trainer_params:
                 BCModule.check_config(trainer_params["pretraining"])
@@ -52,8 +60,6 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
                     default_num_epoch=trainer_params["num_epoch"],
                     **trainer_params["pretraining"],
                 )
-            else:
-                self.bc_module = None

         if load:
             self._load_graph()
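
Together, the last two hunks replace the else branch with a single up-front annotation: declaring self.bc_module as Optional[BCModule] and defaulting it to None gives mypy one declared type for the attribute no matter which branch runs. A minimal standalone sketch of the same pattern, with class and field names that are illustrative rather than taken from the repository:

from typing import Any, Dict, Optional


class Pretrainer:
    """Stand-in for an optional helper such as BCModule."""


class Policy:
    def __init__(self, trainer_params: Dict[str, Any]) -> None:
        # Declare the attribute once with its full type; mypy then knows
        # that self.pretrainer is Optional[Pretrainer] on every code path.
        self.pretrainer: Optional[Pretrainer] = None
        if "pretraining" in trainer_params:
            self.pretrainer = Pretrainer()


print(Policy({"pretraining": {}}).pretrainer)  # a Pretrainer instance
print(Policy({}).pretrainer)                   # None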

ml-agents/mlagents/trainers/sac/policy.py

Lines changed: 2 additions & 3 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 import numpy as np
 import tensorflow as tf

@@ -58,6 +58,7 @@ def __init__(

         with self.graph.as_default():
             # Create pretrainer if needed
+            self.bc_module: Optional[BCModule] = None
             if "pretraining" in trainer_params:
                 BCModule.check_config(trainer_params["pretraining"])
                 self.bc_module = BCModule(
@@ -74,8 +75,6 @@ def __init__(
                     "Pretraining: Samples Per Update is not a valid setting for SAC."
                 )
                 self.bc_module.samples_per_update = 1
-            else:
-                self.bc_module = None

         if load:
             self._load_graph()
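
The SAC policy gets the same treatment as the PPO policy above. One knock-on effect of the Optional annotation is that later uses of the attribute must be narrowed before calling into it; a small sketch of that narrowing, with hypothetical names rather than the repository's real call sites:

from typing import Optional


class BCModuleStub:
    """Illustrative stand-in for BCModule."""

    samples_per_update: int = 0


def clamp_samples_per_update(bc_module: Optional[BCModuleStub]) -> None:
    # Check for None before touching attributes; mypy flags attribute
    # access on a value that may still be None.
    if bc_module is not None:
        bc_module.samples_per_update = 1


clamp_samples_per_update(None)            # no-op
clamp_samples_per_update(BCModuleStub())  # sets the field to 1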

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 1 addition & 2 deletions
@@ -13,7 +13,6 @@
 from mlagents.envs.brain import AllBrainInfo
 from mlagents.envs.action_info import ActionInfoOutputs
 from mlagents.envs.timers import timed
-from mlagents.trainers.buffer import Buffer
 from mlagents.trainers.sac.policy import SACPolicy
 from mlagents.trainers.rl_trainer import RLTrainer, AllRewardsOutput

@@ -121,7 +120,7 @@ def save_replay_buffer(self) -> None:
         with open(filename, "wb") as file_object:
             self.training_buffer.update_buffer.save_to_file(file_object)

-    def load_replay_buffer(self) -> Buffer:
+    def load_replay_buffer(self) -> None:
         """
         Loads the last saved replay buffer from a file.
         """

ml-agents/mlagents/trainers/trainer_metrics.py

Lines changed: 6 additions & 13 deletions
@@ -102,19 +102,12 @@ def end_policy_update(self) -> None:
             self.delta_policy_update = 0
         delta_train_start = time() - self.time_training_start
         LOGGER.debug(
-            " Policy Update Training Metrics for {}: "
-            "\n\t\tTime to update Policy: {:0.3f} s \n"
-            "\t\tTime elapsed since training: {:0.3f} s \n"
-            "\t\tTime for experience collection: {:0.3f} s \n"
-            "\t\tBuffer Length: {} \n"
-            "\t\tReturns : {:0.3f}\n".format(
-                self.brain_name,
-                self.delta_policy_update,
-                delta_train_start,
-                self.delta_last_experience_collection,
-                self.last_buffer_length,
-                self.last_mean_return,
-            )
+            f" Policy Update Training Metrics for {self.brain_name}: "
+            f"\n\t\tTime to update Policy: {self.delta_policy_update:0.3f} s \n"
+            f"\t\tTime elapsed since training: {delta_train_start:0.3f} s \n"
+            f"\t\tTime for experience collection: {(self.delta_last_experience_collection or 0):0.3f} s \n"
+            f"\t\tBuffer Length: {self.last_buffer_length or 0} \n"
+            f"\t\tReturns : {(self.last_mean_return or 0):0.3f}\n"
         )
         self._add_row(delta_train_start)
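
Besides moving the log message to f-strings, the new version guards the metrics that may still be None by falling back to 0, since applying a numeric format spec to None fails at runtime. A tiny sketch of that guard, using an illustrative variable name:

from typing import Optional

last_mean_return: Optional[float] = None

# A format spec such as :0.3f needs a number, so fall back to 0 while the
# metric has not been recorded yet; f"{last_mean_return:0.3f}" would raise
# a TypeError when the value is None.
print(f"Returns : {(last_mean_return or 0):0.3f}")  # Returns : 0.000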

ml-agents/mlagents/trainers/trainer_util.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def initialize_trainers(
     :param multi_gpu: Whether to use multi-GPU training
     :return:
     """
-    trainers = {}
+    trainers: Dict[str, Trainer] = {}
     trainer_parameters_dict = {}
     for brain_name in external_brains:
         trainer_parameters = trainer_config["default"].copy()
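
Finally, the empty dict literal in initialize_trainers gives mypy nothing from which to infer a value type, so the container is annotated explicitly. A minimal sketch of what the annotation buys, where the Trainer class below is only a stand-in for the repository's trainer type and the key is an arbitrary example:

from typing import Dict


class Trainer:
    """Illustrative stand-in for the ml-agents Trainer base class."""


# Annotating the empty dict tells mypy what may be stored in it later:
# assigning a Trainer type-checks, while assigning a plain string would
# be rejected by the checker.
trainers: Dict[str, Trainer] = {}
trainers["3DBallLearning"] = Trainer()
print(list(trainers))  # ['3DBallLearning']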
