@@ -38,6 +38,7 @@ def __init__(
3838 schedule = "fixed" ,
3939 desired_kl = 0.01 ,
4040 device = "cpu" ,
41+ normalize_advantage_per_mini_batch = False ,
4142 # RND parameters
4243 rnd_cfg : dict | None = None ,
4344 # Symmetry parameters
@@ -48,6 +49,7 @@ def __init__(
4849 self .desired_kl = desired_kl
4950 self .schedule = schedule
5051 self .learning_rate = learning_rate
52+ self .normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
5153
5254 # RND components
5355 if rnd_cfg is not None :
@@ -84,8 +86,10 @@ def __init__(
8486 # PPO components
8587 self .actor_critic = actor_critic
8688 self .actor_critic .to (self .device )
87- self . storage = None # initialized later
89+ # Create optimizer
8890 self .optimizer = optim .Adam (self .actor_critic .parameters (), lr = learning_rate )
91+ # Create rollout storage
92+ self .storage : RolloutStorage = None # type: ignore
8993 self .transition = RolloutStorage .Transition ()
9094
9195 # PPO parameters
@@ -168,7 +172,9 @@ def process_env_step(self, rewards, dones, infos):
def compute_returns(self, last_critic_obs):
    """Compute returns and advantages for the collected rollout.

    Evaluates the critic on the final observation to bootstrap the value
    of the step after the last stored transition, then delegates the
    GAE/return computation to the rollout storage.

    Whole-rollout advantage normalization is skipped here when advantages
    are instead normalized per mini-batch inside ``update``.
    """
    # Bootstrap value for the state following the last stored step;
    # detached so the returns target carries no gradient.
    bootstrap_values = self.actor_critic.evaluate(last_critic_obs).detach()
    # Normalize over the whole rollout only if per-mini-batch
    # normalization is disabled (the two modes are mutually exclusive).
    normalize_whole_rollout = not self.normalize_advantage_per_mini_batch
    self.storage.compute_returns(bootstrap_values, self.gamma, self.lam, normalize_advantage=normalize_whole_rollout)
172178
173179 def update (self ): # noqa: C901
174180 mean_value_loss = 0
@@ -213,6 +219,11 @@ def update(self): # noqa: C901
213219 # original batch size
214220 original_batch_size = obs_batch .shape [0 ]
215221
222+ # check if we should normalize advantages per mini batch
223+ if self .normalize_advantage_per_mini_batch :
224+ with torch .no_grad ():
225+ advantages_batch = (advantages_batch - advantages_batch .mean ()) / (advantages_batch .std () + 1e-8 )
226+
216227 # Perform symmetric augmentation
217228 if self .symmetry and self .symmetry ["use_data_augmentation" ]:
218229 # augmentation using symmetry
0 commit comments