@@ -115,6 +115,8 @@ def _on_step(self) -> bool:
         return True
 
     def _on_rollout_end(self) -> None:
+        if self.gen_ctx_manager is not None:
+            self.exit_gen_ctx_manager()
         gen_trajs, ep_lens = self.adversarial_trainer.venv_buffering.pop_trajectories()
         self.adversarial_trainer._check_fixed_horizon(ep_lens)
         gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
@@ -133,9 +135,13 @@ def _on_rollout_end(self) -> None:
         self.gen_ctx_manager = self.adversarial_trainer.logger.accumulate_means("gen")
         self.gen_ctx_manager.__enter__()
 
-    def _on_training_end(self) -> None:
+    def exit_gen_ctx_manager(self) -> None:
         assert self.gen_ctx_manager is not None
         self.gen_ctx_manager.__exit__(None, None, None)
+        self.gen_ctx_manager = None
+
+    def _on_training_end(self) -> None:
+        self.exit_gen_ctx_manager()
 
 
 class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
@@ -514,8 +520,8 @@ def train(
     ) -> None:
         """Alternates between training the generator and discriminator.
 
-        Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
-        a call to `train_disc`, and finally a call to `callback(round)`.
+        Every "round" consists of a call to
+        `train_gen_with_disc(self.gen_train_timesteps)` and a call to `callback(round)`.
 
         Training ends once an additional "round" would cause the number of transitions
         sampled from the environment to exceed `total_timesteps`.
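
The change makes the exit path for the "gen" accumulate_means context reusable and idempotent: `exit_gen_ctx_manager` closes the context and clears the reference, so `_on_rollout_end` can close any window still open before popping trajectories, and `_on_training_end` closes whatever remains at the end. Below is a minimal standalone sketch of that pattern; `DummyLogger` and `GenCallback` are hypothetical stand-ins for `self.adversarial_trainer.logger` and the real callback, not the library's classes.

# Minimal sketch (assumed names, not the real imitation classes) of the
# guarded manual __enter__/__exit__ pattern used in the diff above.
import contextlib


class DummyLogger:
    @contextlib.contextmanager
    def accumulate_means(self, prefix: str):
        # Entering starts accumulating metrics under `prefix/`; exiting stops.
        print(f"start accumulating under '{prefix}/'")
        yield
        print(f"stop accumulating under '{prefix}/'")


class GenCallback:
    def __init__(self, logger: DummyLogger) -> None:
        self.logger = logger
        self.gen_ctx_manager = None

    def exit_gen_ctx_manager(self) -> None:
        assert self.gen_ctx_manager is not None
        self.gen_ctx_manager.__exit__(None, None, None)
        # Clearing the reference prevents a double __exit__ when both
        # _on_rollout_end and _on_training_end run.
        self.gen_ctx_manager = None

    def _on_rollout_end(self) -> None:
        if self.gen_ctx_manager is not None:
            self.exit_gen_ctx_manager()
        # ... pop generator trajectories and train the discriminator here ...
        self.gen_ctx_manager = self.logger.accumulate_means("gen")
        self.gen_ctx_manager.__enter__()

    def _on_training_end(self) -> None:
        self.exit_gen_ctx_manager()


cb = GenCallback(DummyLogger())
cb._on_rollout_end()   # nothing to exit yet; opens the first "gen" window
cb._on_rollout_end()   # closes the previous window, opens a new one
cb._on_training_end()  # closes the final window exactly once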