Skip to content

Commit b01d51b

Browse files
taufeeque9 authored and ernestum committed
Fix test errors
1 parent 0628772 commit b01d51b

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/imitation/algorithms/adversarial/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def _on_step(self) -> bool:
115115
return True
116116

117117
def _on_rollout_end(self) -> None:
118+
if self.gen_ctx_manager is not None:
119+
self.exit_gen_ctx_manager()
118120
gen_trajs, ep_lens = self.adversarial_trainer.venv_buffering.pop_trajectories()
119121
self.adversarial_trainer._check_fixed_horizon(ep_lens)
120122
gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
@@ -133,9 +135,13 @@ def _on_rollout_end(self) -> None:
133135
self.gen_ctx_manager = self.adversarial_trainer.logger.accumulate_means("gen")
134136
self.gen_ctx_manager.__enter__()
135137

136-
def _on_training_end(self) -> None:
138+
def exit_gen_ctx_manager(self) -> None:
137139
assert self.gen_ctx_manager is not None
138140
self.gen_ctx_manager.__exit__(None, None, None)
141+
self.gen_ctx_manager = None
142+
143+
def _on_training_end(self) -> None:
144+
self.exit_gen_ctx_manager()
139145

140146

141147
class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
@@ -514,8 +520,8 @@ def train(
514520
) -> None:
515521
"""Alternates between training the generator and discriminator.
516522
517-
Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
518-
a call to `train_disc`, and finally a call to `callback(round)`.
523+
Every "round" consists of a call to
524+
`train_gen_with_disc(self.gen_train_timesteps)` and a call to `callback(round)`.
519525
520526
Training ends once an additional "round" would cause the number of transitions
521527
sampled from the environment to exceed `total_timesteps`.

tests/algorithms/test_adversarial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_train_gen_train_disc_no_crash(
232232
n_updates: int = 2,
233233
) -> None:
234234
trainer_parametrized.train_gen_with_disc(
235-
n_updates * trainer_parametrized.gen_train_timesteps
235+
n_updates * trainer_parametrized.gen_train_timesteps,
236236
)
237237

238238

0 commit comments

Comments
 (0)