@@ -30,27 +30,10 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
 
         reward_signal_configs = trainer_params["reward_signals"]
 
+        self.create_model(brain, trainer_params, reward_signal_configs, seed)
+
         self.reward_signals = {}
         with self.graph.as_default():
-            self.model = PPOModel(
-                brain,
-                lr=float(trainer_params["learning_rate"]),
-                h_size=int(trainer_params["hidden_units"]),
-                epsilon=float(trainer_params["epsilon"]),
-                beta=float(trainer_params["beta"]),
-                max_step=float(trainer_params["max_steps"]),
-                normalize=trainer_params["normalize"],
-                use_recurrent=trainer_params["use_recurrent"],
-                num_layers=int(trainer_params["num_layers"]),
-                m_size=self.m_size,
-                seed=seed,
-                stream_names=list(reward_signal_configs.keys()),
-                vis_encode_type=EncoderType(
-                    trainer_params.get("vis_encode_type", "simple")
-                ),
-            )
-            self.model.create_ppo_optimizer()
-
             # Create reward signals
             for reward_signal, config in reward_signal_configs.items():
                 self.reward_signals[reward_signal] = create_reward_signal(
@@ -102,6 +85,34 @@ def __init__(self, seed, brain, trainer_params, is_training, load):
             "update_batch": self.model.update_batch,
         }
 
+    def create_model(self, brain, trainer_params, reward_signal_configs, seed):
89+ """
90+ Create PPO model
91+ :param brain: Assigned Brain object.
92+ :param trainer_params: Defined training parameters.
93+ :param reward_signal_configs: Reward signal config
94+ :param seed: Random seed.
95+ """
+        with self.graph.as_default():
+            self.model = PPOModel(
+                brain=brain,
+                lr=float(trainer_params["learning_rate"]),
+                h_size=int(trainer_params["hidden_units"]),
+                epsilon=float(trainer_params["epsilon"]),
+                beta=float(trainer_params["beta"]),
+                max_step=float(trainer_params["max_steps"]),
+                normalize=trainer_params["normalize"],
+                use_recurrent=trainer_params["use_recurrent"],
+                num_layers=int(trainer_params["num_layers"]),
+                m_size=self.m_size,
+                seed=seed,
+                stream_names=list(reward_signal_configs.keys()),
+                vis_encode_type=EncoderType(
+                    trainer_params.get("vis_encode_type", "simple")
+                ),
+            )
+            self.model.create_ppo_optimizer()
+
     @timed
     def evaluate(self, brain_info):
         """
@@ -143,58 +154,62 @@ def update(self, mini_batch, num_sequences):
         :param mini_batch: Experience batch.
         :return: Output from update process.
         """
+        feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences)
+        run_out = self._execute_model(feed_dict, self.update_dict)
+        return run_out
+
+    def construct_feed_dict(self, model, mini_batch, num_sequences):
         feed_dict = {
-            self.model.batch_size: num_sequences,
-            self.model.sequence_length: self.sequence_length,
-            self.model.mask_input: mini_batch["masks"].flatten(),
-            self.model.advantage: mini_batch["advantages"].reshape([-1, 1]),
-            self.model.all_old_log_probs: mini_batch["action_probs"].reshape(
-                [-1, sum(self.model.act_size)]
+            model.batch_size: num_sequences,
+            model.sequence_length: self.sequence_length,
+            model.mask_input: mini_batch["masks"].flatten(),
+            model.advantage: mini_batch["advantages"].reshape([-1, 1]),
+            model.all_old_log_probs: mini_batch["action_probs"].reshape(
+                [-1, sum(model.act_size)]
             ),
         }
         for name in self.reward_signals:
-            feed_dict[self.model.returns_holders[name]] = mini_batch[
+            feed_dict[model.returns_holders[name]] = mini_batch[
                 "{}_returns".format(name)
             ].flatten()
-            feed_dict[self.model.old_values[name]] = mini_batch[
+            feed_dict[model.old_values[name]] = mini_batch[
                 "{}_value_estimates".format(name)
             ].flatten()
 
         if self.use_continuous_act:
-            feed_dict[self.model.output_pre] = mini_batch["actions_pre"].reshape(
-                [-1, self.model.act_size[0]]
+            feed_dict[model.output_pre] = mini_batch["actions_pre"].reshape(
+                [-1, model.act_size[0]]
             )
-            feed_dict[self.model.epsilon] = mini_batch["random_normal_epsilon"].reshape(
-                [-1, self.model.act_size[0]]
+            feed_dict[model.epsilon] = mini_batch["random_normal_epsilon"].reshape(
+                [-1, model.act_size[0]]
             )
         else:
-            feed_dict[self.model.action_holder] = mini_batch["actions"].reshape(
-                [-1, len(self.model.act_size)]
+            feed_dict[model.action_holder] = mini_batch["actions"].reshape(
+                [-1, len(model.act_size)]
             )
             if self.use_recurrent:
-                feed_dict[self.model.prev_action] = mini_batch["prev_action"].reshape(
-                    [-1, len(self.model.act_size)]
+                feed_dict[model.prev_action] = mini_batch["prev_action"].reshape(
+                    [-1, len(model.act_size)]
                 )
-            feed_dict[self.model.action_masks] = mini_batch["action_mask"].reshape(
+            feed_dict[model.action_masks] = mini_batch["action_mask"].reshape(
                 [-1, sum(self.brain.vector_action_space_size)]
             )
         if self.use_vec_obs:
-            feed_dict[self.model.vector_in] = mini_batch["vector_obs"].reshape(
+            feed_dict[model.vector_in] = mini_batch["vector_obs"].reshape(
                 [-1, self.vec_obs_size]
             )
-        if self.model.vis_obs_size > 0:
-            for i, _ in enumerate(self.model.visual_in):
+        if model.vis_obs_size > 0:
+            for i, _ in enumerate(model.visual_in):
                 _obs = mini_batch["visual_obs%d" % i]
                 if self.sequence_length > 1 and self.use_recurrent:
                     (_batch, _seq, _w, _h, _c) = _obs.shape
-                    feed_dict[self.model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
+                    feed_dict[model.visual_in[i]] = _obs.reshape([-1, _w, _h, _c])
                 else:
-                    feed_dict[self.model.visual_in[i]] = _obs
         if self.use_recurrent:
             mem_in = mini_batch["memory"][:, 0, :]
-            feed_dict[self.model.memory_in] = mem_in
-        run_out = self._execute_model(feed_dict, self.update_dict)
-        return run_out
+                    feed_dict[model.visual_in[i]] = _obs
+            feed_dict[model.memory_in] = mem_in
+        return feed_dict
 
     def get_value_estimates(
         self, brain_info: BrainInfo, idx: int, done: bool
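
Note: extracting create_model() and construct_feed_dict() turns both steps into overridable hooks. A minimal sketch of how a subclass might reuse them; the class name ExtendedPPOPolicy, the extra_input placeholder, and the "extra_obs" batch key are hypothetical, and only the helper signatures shown in the diff above are assumed.

class ExtendedPPOPolicy(PPOPolicy):
    def create_model(self, brain, trainer_params, reward_signal_configs, seed):
        # Delegate graph construction to the base implementation; a variant
        # could build the model differently here instead.
        super().create_model(brain, trainer_params, reward_signal_configs, seed)

    def construct_feed_dict(self, model, mini_batch, num_sequences):
        # Reuse the base feed dict, then layer on extra inputs when the model
        # defines them (placeholder and batch key are illustrative only).
        feed_dict = super().construct_feed_dict(model, mini_batch, num_sequences)
        if hasattr(model, "extra_input"):
            feed_dict[model.extra_input] = mini_batch["extra_obs"]
        return feed_dict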