diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index fb6b9f281..52fb1916f 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -19,6 +19,7 @@ Bug Fixes:
 - Update env checker to warn users when using Graph space (@dhruvmalik007).
 - Fixed memory leak in ``VecVideoRecorder`` where ``recorded_frames`` stayed in memory due to reference in the moviepy clip (@copilot)
 - Remove double space in `StopTrainingOnRewardThreshold` callback message (@sea-bass)
+- Add ``close()`` method to ``BaseAlgorithm`` to prevent memory leaks in sequential training loops (#1966)
 
 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index f2c205166..e4fbd00f1 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -1,5 +1,6 @@
 """Abstract base classes for RL algorithms."""
 
+import gc
 import io
 import pathlib
 import time
@@ -866,6 +867,30 @@ def save(
         save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
 
+    def close(self) -> None:
+        """
+        Clean up resources after training or prediction to prevent memory leaks
+        when calling :meth:`learn()` repeatedly with new environments.
+
+        Fixes https://github.com/DLR-RM/stable-baselines3/issues/1996
+        """
+        if self.env is not None:
+            self.env.close()
+            self.env = None
+
+        if hasattr(self, "rollout_buffer") and self.rollout_buffer is not None:
+            del self.rollout_buffer
+            self.rollout_buffer = None
+
+        if hasattr(self, "policy") and self.policy is not None:
+            del self.policy
+            self.policy = None
+
+        if self.device.type == "cuda":
+            th.cuda.empty_cache()
+
+        gc.collect()
+
     def dump_logs(self) -> None:
         """
         Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
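
Usage sketch (not part of the patch) of the sequential-training loop the new ``close()`` method is meant to support; the algorithm (PPO), environment (``CartPole-v1``), and timestep budget below are illustrative assumptions, not taken from the patch:

    import gymnasium as gym

    from stable_baselines3 import PPO

    for seed in range(3):
        # Build a fresh env and model for each run in the sequence.
        env = gym.make("CartPole-v1")
        model = PPO("MlpPolicy", env, seed=seed, verbose=0)
        model.learn(total_timesteps=1_000)

        # Release the env, policy, and rollout buffer before the next run so
        # repeated training does not accumulate memory (uses the close()
        # method added in this patch).
        model.close()

Because ``close()`` sets ``self.env`` and ``self.policy`` to ``None``, a new model must be constructed for each subsequent run, as in the loop above.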