@@ -231,7 +231,7 @@ def update(self) -> dict[str, float]:
231231 # Check if we should normalize advantages per mini batch
232232 if self .normalize_advantage_per_mini_batch :
233233 with torch .no_grad ():
234- batch .advantages = (batch .advantages - batch .advantages .mean ()) / (batch .advantages .std () + 1e-8 )
234+ batch .advantages = (batch .advantages - batch .advantages .mean ()) / (batch .advantages .std () + 1e-8 ) # type: ignore
235235
236236 # Perform symmetric augmentation
237237 if self .symmetry and self .symmetry ["use_data_augmentation" ]:
@@ -259,7 +259,7 @@ def update(self) -> dict[str, float]:
259259 hidden_state = batch .hidden_states [0 ],
260260 stochastic_output = True ,
261261 )
262- actions_log_prob = self .actor .get_output_log_prob (batch .actions )
262+ actions_log_prob = self .actor .get_output_log_prob (batch .actions ) # type: ignore
263263 values = self .critic (batch .observations , masks = batch .masks , hidden_state = batch .hidden_states [1 ])
264264 # Note: We only keep the distribution parameters and entropy of the first augmentation (the original one)
265265 distribution_params = tuple (p [:original_batch_size ] for p in self .actor .output_distribution_params )
@@ -268,7 +268,7 @@ def update(self) -> dict[str, float]:
268268 # Compute KL divergence and adapt the learning rate
269269 if self .desired_kl is not None and self .schedule == "adaptive" :
270270 with torch .inference_mode ():
271- kl = self .actor .get_kl_divergence (batch .old_distribution_params , distribution_params )
271+ kl = self .actor .get_kl_divergence (batch .old_distribution_params , distribution_params ) # type: ignore
272272 kl_mean = torch .mean (kl )
273273
274274 # Reduce the KL divergence across all GPUs
@@ -294,9 +294,9 @@ def update(self) -> dict[str, float]:
294294 param_group ["lr" ] = self .learning_rate
295295
296296 # Surrogate loss
297- ratio = torch .exp (actions_log_prob - torch .squeeze (batch .old_actions_log_prob ))
298- surrogate = - torch .squeeze (batch .advantages ) * ratio
299- surrogate_clipped = - torch .squeeze (batch .advantages ) * torch .clamp (
297+ ratio = torch .exp (actions_log_prob - torch .squeeze (batch .old_actions_log_prob )) # type: ignore
298+ surrogate = - torch .squeeze (batch .advantages ) * ratio # type: ignore
299+ surrogate_clipped = - torch .squeeze (batch .advantages ) * torch .clamp ( # type: ignore
300300 ratio , 1.0 - self .clip_param , 1.0 + self .clip_param
301301 )
302302 surrogate_loss = torch .max (surrogate , surrogate_clipped ).mean ()
0 commit comments