@@ -19,12 +19,12 @@
 class PPO:
     """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347)."""
 
-    actor_critic: ActorCritic
+    policy: ActorCritic
     """The actor critic module."""
 
     def __init__(
         self,
-        actor_critic,
+        policy,
         num_learning_epochs=1,
         num_mini_batches=1,
         clip_param=0.2,
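The hunk above renames the constructor keyword from `actor_critic` to `policy`. A minimal caller-side sketch of the new call, assuming the `ActorCritic` module from the same package; the network sizes and hyperparameter values are illustrative placeholders, not taken from this diff.

```python
# Hypothetical caller-side sketch: only the keyword rename comes from this diff.
from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic

actor_critic = ActorCritic(num_actor_obs=48, num_critic_obs=48, num_actions=12)
# previously: PPO(actor_critic=actor_critic, ...)
ppo = PPO(policy=actor_critic, num_learning_epochs=5, num_mini_batches=4, clip_param=0.2, device="cpu")
```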
@@ -84,10 +84,10 @@ def __init__(
             self.symmetry = None
 
         # PPO components
-        self.actor_critic = actor_critic
-        self.actor_critic.to(self.device)
+        self.policy = policy
+        self.policy.to(self.device)
         # Create optimizer
-        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
+        self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate)
         # Create rollout storage
         self.storage: RolloutStorage = None  # type: ignore
         self.transition = RolloutStorage.Transition()
@@ -103,41 +103,38 @@ def __init__(
         self.max_grad_norm = max_grad_norm
         self.use_clipped_value_loss = use_clipped_value_loss
 
-    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
+    def init_storage(
+        self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape
+    ):
         # create memory for RND as well :)
         if self.rnd:
             rnd_state_shape = [self.rnd.num_states]
         else:
             rnd_state_shape = None
         # create rollout storage
         self.storage = RolloutStorage(
+            training_type,
             num_envs,
             num_transitions_per_env,
             actor_obs_shape,
             critic_obs_shape,
-            action_shape,
+            actions_shape,
             rnd_state_shape,
             self.device,
         )
 
-    def test_mode(self):
-        self.actor_critic.test()
-
-    def train_mode(self):
-        self.actor_critic.train()
-
     def act(self, obs, critic_obs):
-        if self.actor_critic.is_recurrent:
-            self.transition.hidden_states = self.actor_critic.get_hidden_states()
-        # Compute the actions and values
-        self.transition.actions = self.actor_critic.act(obs).detach()
-        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
-        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
-        self.transition.action_mean = self.actor_critic.action_mean.detach()
-        self.transition.action_sigma = self.actor_critic.action_std.detach()
+        if self.policy.is_recurrent:
+            self.transition.hidden_states = self.policy.get_hidden_states()
+        # compute the actions and values
+        self.transition.actions = self.policy.act(obs).detach()
+        self.transition.values = self.policy.evaluate(critic_obs).detach()
+        self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
+        self.transition.action_mean = self.policy.action_mean.detach()
+        self.transition.action_sigma = self.policy.action_std.detach()
         # need to record obs and critic_obs before env.step()
         self.transition.observations = obs
-        self.transition.critic_observations = critic_obs
+        self.transition.privileged_observations = critic_obs
         return self.transition.actions
 
     def process_env_step(self, rewards, dones, infos):
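`init_storage` in the hunk above gains a leading `training_type` argument and renames `action_shape` to `actions_shape`, both of which flow into the `RolloutStorage` constructor. A sketch of the resulting positional order; the concrete shapes and the `"rl"` string are assumptions for illustration only.

```python
# Sketch of the new RolloutStorage positional order used above; values are placeholders
# and the "rl" training_type string is an assumption, not quoted from this diff.
from rsl_rl.storage import RolloutStorage

storage = RolloutStorage(
    "rl",    # training_type (new leading argument)
    4096,    # num_envs
    24,      # num_transitions_per_env
    [48],    # actor_obs_shape
    [48],    # critic_obs_shape
    [12],    # actions_shape (renamed from action_shape)
    None,    # rnd_state_shape (None when RND is disabled)
    "cpu",   # device
)
```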
@@ -164,14 +161,14 @@ def process_env_step(self, rewards, dones, infos):
                 self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
             )
 
-        # Record the transition
+        # record the transition
         self.storage.add_transitions(self.transition)
         self.transition.clear()
-        self.actor_critic.reset(dones)
+        self.policy.reset(dones)
 
     def compute_returns(self, last_critic_obs):
         # compute value for the last step
-        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
+        last_values = self.policy.evaluate(last_critic_obs).detach()
         self.storage.compute_returns(
             last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
         )
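Taken together, `act`, `process_env_step`, and `compute_returns` form the per-iteration collection loop. A hedged sketch of how a runner might drive them; the `env.step(...)` return signature and the choice to feed the actor observations to the critic are assumptions, not part of this diff.

```python
# Hypothetical runner-side loop; env.step(...) returning (obs, rewards, dones, infos)
# and the shared actor/critic observations are assumptions.
import torch


def collect_rollout(ppo, env, obs, num_steps=24):
    with torch.inference_mode():
        for _ in range(num_steps):
            actions = ppo.act(obs, obs)  # records actions, values, and log-probs in the transition
            obs, rewards, dones, infos = env.step(actions)
            ppo.process_env_step(rewards, dones, infos)
        # bootstrap with the value of the last observation before computing GAE returns
        ppo.compute_returns(obs)
    return obs
```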
@@ -192,7 +189,7 @@ def update(self):  # noqa: C901
         mean_symmetry_loss = None
 
         # generator for mini batches
-        if self.actor_critic.is_recurrent:
+        if self.policy.is_recurrent:
             generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
         else:
             generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
@@ -230,10 +227,10 @@ def update(self):  # noqa: C901
                 data_augmentation_func = self.symmetry["data_augmentation_func"]
                 # returned shape: [batch_size * num_aug, ...]
                 obs_batch, actions_batch = data_augmentation_func(
-                    obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], is_critic=False
+                    obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], obs_type="policy"
                 )
                 critic_obs_batch, _ = data_augmentation_func(
-                    obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], is_critic=True
+                    obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], obs_type="critic"
                 )
                 # compute number of augmentations per sample
                 num_aug = int(obs_batch.shape[0] / original_batch_size)
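The augmentation callback in the hunk above now receives an `obs_type` string (`"policy"` or `"critic"`) instead of the old boolean `is_critic`. A toy sketch of a function compatible with that interface; the sign-flip mirroring is a made-up placeholder for an environment-specific symmetry map.

```python
# Toy data_augmentation_func matching the new keyword interface; the mirroring rule
# (simple sign flip) is a placeholder, real environments define their own mapping.
import torch


def data_augmentation_func(obs=None, actions=None, env=None, obs_type="policy"):
    # stack [original; mirrored] along the batch dimension so that the first
    # batch_size rows stay the un-augmented samples, as the caller above assumes
    aug_obs = None if obs is None else torch.cat([obs, -obs], dim=0)
    aug_actions = None if actions is None else torch.cat([actions, -actions], dim=0)
    return aug_obs, aug_actions
```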
@@ -246,19 +243,17 @@ def update(self):  # noqa: C901
                 returns_batch = returns_batch.repeat(num_aug, 1)
 
             # Recompute actions log prob and entropy for current batch of transitions
-            # Note: we need to do this because we updated the actor_critic with the new parameters
+            # Note: we need to do this because we updated the policy with the new parameters
             # -- actor
-            self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
-            actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
+            self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
+            actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch)
             # -- critic
-            value_batch = self.actor_critic.evaluate(
-                critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
-            )
+            value_batch = self.policy.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
             # -- entropy
             # we only keep the entropy of the first augmentation (the original one)
-            mu_batch = self.actor_critic.action_mean[:original_batch_size]
-            sigma_batch = self.actor_critic.action_std[:original_batch_size]
-            entropy_batch = self.actor_critic.entropy[:original_batch_size]
+            mu_batch = self.policy.action_mean[:original_batch_size]
+            sigma_batch = self.policy.action_std[:original_batch_size]
+            entropy_batch = self.policy.entropy[:original_batch_size]
 
             # KL
             if self.desired_kl is not None and self.schedule == "adaptive":
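The body of the adaptive-KL branch is unchanged and therefore not shown in this hunk. For reference, a sketch of the usual rule: the KL between the stored and the recomputed diagonal Gaussians decides whether to shrink or grow the learning rate. The thresholds, clipping bounds, and variable names below are assumptions, not quoted from this commit.

```python
# Reference sketch only; the actual adaptive branch lives outside this diff.
import torch


def kl_adaptive_lr(learning_rate, desired_kl, mu_batch, sigma_batch, old_mu_batch, old_sigma_batch):
    # KL(old || new) for diagonal Gaussians, summed over action dimensions
    kl = torch.sum(
        torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
        + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
        / (2.0 * torch.square(sigma_batch))
        - 0.5,
        dim=-1,
    )
    kl_mean = torch.mean(kl)
    # policy moved too far -> lower the learning rate; barely moved -> raise it
    if kl_mean > desired_kl * 2.0:
        learning_rate = max(1e-5, learning_rate / 1.5)
    elif 0.0 < kl_mean < desired_kl / 2.0:
        learning_rate = min(1e-2, learning_rate * 1.5)
    return learning_rate
```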
@@ -308,21 +303,21 @@ def update(self):  # noqa: C901
                 if not self.symmetry["use_data_augmentation"]:
                     data_augmentation_func = self.symmetry["data_augmentation_func"]
                     obs_batch, _ = data_augmentation_func(
-                        obs=obs_batch, actions=None, env=self.symmetry["_env"], is_critic=False
+                        obs=obs_batch, actions=None, env=self.symmetry["_env"], obs_type="policy"
                     )
                     # compute number of augmentations per sample
                     num_aug = int(obs_batch.shape[0] / original_batch_size)
 
                 # actions predicted by the actor for symmetrically-augmented observations
-                mean_actions_batch = self.actor_critic.act_inference(obs_batch.detach().clone())
+                mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone())
 
                 # compute the symmetrically augmented actions
                 # note: we are assuming the first augmentation is the original one.
                 # We do not use the action_batch from earlier since that action was sampled from the distribution.
                 # However, the symmetry loss is computed using the mean of the distribution.
                 action_mean_orig = mean_actions_batch[:original_batch_size]
                 _, actions_mean_symm_batch = data_augmentation_func(
-                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"], is_critic=False
+                    obs=None, actions=action_mean_orig, env=self.symmetry["_env"], obs_type="policy"
                 )
 
                 # compute the loss (we skip the first augmentation as it is the original one)
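The loss referred to by the last comment in the hunk above is a mirror-consistency term between the actor's mean actions on augmented observations and the symmetry-transformed original means. A hedged sketch of that computation; the MSE reduction is assumed from the surrounding comments rather than quoted from this diff.

```python
# Assumed form of the mirror loss the comment above introduces; only the augmented
# slices enter the loss, the first (original) slice is skipped.
import torch


def mirror_loss(mean_actions_batch, actions_mean_symm_batch, original_batch_size):
    return torch.nn.functional.mse_loss(
        mean_actions_batch[original_batch_size:],
        actions_mean_symm_batch.detach()[original_batch_size:],
    )
```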
@@ -349,7 +344,7 @@ def update(self):  # noqa: C901
             # -- For PPO
             self.optimizer.zero_grad()
             loss.backward()
-            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
+            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
             self.optimizer.step()
             # -- For RND
             if self.rnd_optimizer:
@@ -382,4 +377,15 @@ def update(self):  # noqa: C901
         # -- Clear the storage
         self.storage.clear()
 
-        return mean_value_loss, mean_surrogate_loss, mean_entropy, mean_rnd_loss, mean_symmetry_loss
+        # construct the loss dictionary
+        loss_dict = {
+            "value_function": mean_value_loss,
+            "surrogate": mean_surrogate_loss,
+            "entropy": mean_entropy,
+        }
+        if self.rnd:
+            loss_dict["rnd"] = mean_rnd_loss
+        if self.symmetry:
+            loss_dict["symmetry"] = mean_symmetry_loss
+
+        return loss_dict
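Because `update()` now returns a dictionary instead of a fixed tuple, callers no longer unpack placeholders for disabled features; the optional `rnd` and `symmetry` entries only appear when those components are configured. A hypothetical caller-side logging helper; the `writer` object is assumed to expose a TensorBoard-style `add_scalar(tag, value, step)` method.

```python
# Hypothetical caller-side helper, not part of this diff.
def log_losses(loss_dict, writer, step):
    # previously: mean_value_loss, mean_surrogate_loss, ... = ppo.update()
    for name, value in loss_dict.items():
        writer.add_scalar(f"Loss/{name}", value, step)
```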