Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Bug Fixes:
- Update env checker to warn users when using Graph space (@dhruvmalik007).
- Fixed memory leak in ``VecVideoRecorder`` where ``recorded_frames`` stayed in memory due to reference in the moviepy clip (@copilot)
- Remove double space in `StopTrainingOnRewardThreshold` callback message (@sea-bass)
- Added ``close()`` method to ``BaseAlgorithm`` to prevent memory leaks in sequential training loops (#1966)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
25 changes: 25 additions & 0 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Abstract base classes for RL algorithms."""

import gc
import io
import pathlib
import time
Expand Down Expand Up @@ -866,6 +867,30 @@ def save(

save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)

def close(self) -> None:
    """
    Release the resources held by this algorithm instance.

    Closes the environment (if any), drops the references to the rollout
    buffer, replay buffer and policy (when set), runs the garbage collector
    and, when the model lives on a CUDA device, empties PyTorch's CUDA cache.
    Intended to prevent memory leaks when calling :meth:`learn()` repeatedly
    with new environments.

    Calling this method more than once is safe: attributes that are already
    ``None`` (or missing) are skipped. ``env`` and ``policy`` are set to
    ``None``, so the instance should not be used afterwards.

    Fixes https://github.com/DLR-RM/stable-baselines3/issues/1996
    """
    # NOTE(review): the changelog entry for this change cites #1966 while the
    # link above points to issue 1996 -- confirm which number is correct.
    if self.env is not None:
        self.env.close()
        self.env = None

    # Rebinding to None drops the reference; a preceding ``del`` is redundant.
    # ``replay_buffer`` is released as well so that off-policy algorithms are
    # covered too (presumably the attribute name they use -- verify against
    # the off-policy subclasses); attributes that are not defined are skipped.
    for attr_name in ("rollout_buffer", "replay_buffer", "policy"):
        if getattr(self, attr_name, None) is not None:
            setattr(self, attr_name, None)

    # Collect first: objects kept alive only by reference cycles still own
    # their CUDA memory until the garbage collector frees them, so emptying
    # the cache before collecting would leave those blocks allocated.
    gc.collect()

    if self.device.type == "cuda":
        th.cuda.empty_cache()

def dump_logs(self) -> None:
"""
Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
Expand Down