@@ -1,10 +1,10 @@
 import logging
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import tensorflow as tf
 
 from mlagents.envs.timers import timed
-from mlagents.envs.brain import BrainInfo
+from mlagents.envs.brain import BrainInfo, BrainParameters
 from mlagents.trainers.models import EncoderType, LearningRateSchedule
 from mlagents.trainers.ppo.models import PPOModel
 from mlagents.trainers.tf_policy import TFPolicy
@@ -17,7 +17,14 @@
 
 
 class PPOPolicy(TFPolicy):
-    def __init__(self, seed, brain, trainer_params, is_training, load):
+    def __init__(
+        self,
+        seed: int,
+        brain: BrainParameters,
+        trainer_params: Dict[str, Any],
+        is_training: bool,
+        load: bool,
+    ):
         """
         Policy for Proximal Policy Optimization Networks.
         :param seed: Random seed.
@@ -29,8 +36,8 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
         super().__init__(seed, brain, trainer_params)
 
         reward_signal_configs = trainer_params["reward_signals"]
-        self.inference_dict = {}
-        self.update_dict = {}
+        self.inference_dict: Dict[str, tf.Tensor] = {}
+        self.update_dict: Dict[str, tf.Tensor] = {}
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",
@@ -42,6 +49,7 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
         self.create_reward_signals(reward_signal_configs)
 
         with self.graph.as_default():
+            self.bc_module: Optional[BCModule] = None
             # Create pretrainer if needed
             if "pretraining" in trainer_params:
                 BCModule.check_config(trainer_params["pretraining"])
@@ -52,8 +60,6 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
                     default_num_epoch=trainer_params["num_epoch"],
                     **trainer_params["pretraining"],
                 )
-            else:
-                self.bc_module = None
 
         if load:
             self._load_graph()
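
Beyond the type annotations, the behavioral point of this diff is that `self.bc_module` is now declared once, with its `Optional` type, before the `"pretraining"` check instead of in an `else` branch, so the attribute exists on every path through the constructor. Below is a minimal, self-contained sketch of that same pattern; `FakeModule` and `Policy` are hypothetical stand-ins for illustration only, not the real `BCModule` or `PPOPolicy`.

```python
from typing import Any, Dict, Optional


class FakeModule:
    """Hypothetical stand-in for BCModule, used only to illustrate the pattern."""

    def __init__(self, **settings: Any) -> None:
        self.settings = settings


class Policy:
    def __init__(self, trainer_params: Dict[str, Any]) -> None:
        # Declare the attribute (and its Optional type) up front so it is
        # defined whether or not the optional config section is present.
        self.bc_module: Optional[FakeModule] = None
        if "pretraining" in trainer_params:
            self.bc_module = FakeModule(**trainer_params["pretraining"])


if __name__ == "__main__":
    with_pretraining = Policy({"pretraining": {"strength": 0.5}})
    without_pretraining = Policy({})
    print(with_pretraining.bc_module, without_pretraining.bc_module)
```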