@@ -30,7 +30,7 @@ def __init__(
3030 init_noise_std : float = 1.0 ,
3131 noise_std_type : str = "scalar" ,
3232 ** kwargs ,
33- ):
33+ ) -> None :
3434 if kwargs :
3535 print (
3636 "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
@@ -170,12 +170,10 @@ def __init__(
170170 # disable args validation for speedup
171171 Normal .set_default_validate_args (False )
172172
173- def update_distribution (self , mlp_obs : torch .Tensor , cnn_obs : dict [str , torch .Tensor ]):
173+ def update_distribution (self , mlp_obs : torch .Tensor , cnn_obs : dict [str , torch .Tensor ]) -> None :
174174 if self .actor_cnns is not None :
175175 # encode the 2D actor observations
176- cnn_enc_list = []
177- for obs_group in self .actor_obs_group_2d :
178- cnn_enc_list .append (self .actor_cnns [obs_group ](cnn_obs [obs_group ]))
176+ cnn_enc_list = [self .actor_cnns [obs_group ](cnn_obs [obs_group ]) for obs_group in self .actor_obs_group_2d ]
179177 cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
180178 # update mlp obs
181179 mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
@@ -194,9 +192,7 @@ def act_inference(self, obs):
194192
195193 if self .actor_cnns is not None :
196194 # encode the 2D actor observations
197- cnn_enc_list = []
198- for obs_group in self .actor_obs_group_2d :
199- cnn_enc_list .append (self .actor_cnns [obs_group ](cnn_obs [obs_group ]))
195+ cnn_enc_list = [self .actor_cnns [obs_group ](cnn_obs [obs_group ]) for obs_group in self .actor_obs_group_2d ]
200196 cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
201197 # update mlp obs
202198 mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
@@ -209,34 +205,28 @@ def evaluate(self, obs, **kwargs):
209205
210206 if self .critic_cnns is not None :
211207 # encode the 2D critic observations
212- cnn_enc_list = []
213- for obs_group in self .critic_obs_group_2d :
214- cnn_enc_list .append (self .critic_cnns [obs_group ](cnn_obs [obs_group ]))
208+ cnn_enc_list = [self .critic_cnns [obs_group ](cnn_obs [obs_group ]) for obs_group in self .critic_obs_group_2d ]
215209 cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
216210 # update mlp obs
217211 mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
218212
219213 return self .critic (mlp_obs )
220214
221215 def get_actor_obs (self , obs ):
222- obs_list_1d = []
223216 obs_dict_2d = {}
224- for obs_group in self .actor_obs_group_1d :
225- obs_list_1d .append (obs [obs_group ])
217+ obs_list_1d = [obs [obs_group ] for obs_group in self .actor_obs_group_1d ]
226218 for obs_group in self .actor_obs_group_2d :
227219 obs_dict_2d [obs_group ] = obs [obs_group ]
228220 return torch .cat (obs_list_1d , dim = - 1 ), obs_dict_2d
229221
230222 def get_critic_obs (self , obs ):
231- obs_list_1d = []
232223 obs_dict_2d = {}
233- for obs_group in self .critic_obs_group_1d :
234- obs_list_1d .append (obs [obs_group ])
224+ obs_list_1d = [obs [obs_group ] for obs_group in self .critic_obs_group_1d ]
235225 for obs_group in self .critic_obs_group_2d :
236226 obs_dict_2d [obs_group ] = obs [obs_group ]
237227 return torch .cat (obs_list_1d , dim = - 1 ), obs_dict_2d
238228
239- def update_normalization (self , obs ):
229+ def update_normalization (self , obs ) -> None :
240230 if self .actor_obs_normalization :
241231 actor_obs , _ = self .get_actor_obs (obs )
242232 self .actor_obs_normalizer .update (actor_obs )
0 commit comments