Skip to content

Commit 57e8b97

Browse files
authored
Fix video recorder and add test (#2063)
* Fix video recorder and add test * Update github CI * Install ffmpeg * Revert "Update github CI" This reverts commit 07791e9. * Skip VecVideoRecorder test on github
1 parent 0fd0db0 commit 57e8b97

File tree

4 files changed

+65
-14
lines changed

4 files changed

+65
-14
lines changed

docs/misc/changelog.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Changelog
44
==========
55

6-
Release 2.5.0a0 (WIP)
6+
Release 2.5.0a1 (WIP)
77
--------------------------
88

99
Breaking Changes:
@@ -42,6 +42,14 @@ Documentation:
4242
- Add FootstepNet Envs to the project page (@cgaspard3333)
4343
- Added FRASA to the project page (@MarcDcls)
4444

45+
Release 2.4.1 (2024-12-20)
46+
--------------------------
47+
48+
Bug Fixes:
49+
^^^^^^^^^^
50+
- Fixed a bug introduced in v2.4.0 where the ``VecVideoRecorder`` would override videos
51+
52+
4553
Release 2.4.0 (2024-11-18)
4654
--------------------------
4755

stable_baselines3/common/vec_env/vec_video_recorder.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class VecVideoRecorder(VecEnvWrapper):
2929
:param name_prefix: Prefix to the video name
3030
"""
3131

32+
video_name: str
33+
video_path: str
34+
3235
def __init__(
3336
self,
3437
venv: VecEnv,
@@ -50,7 +53,7 @@ def __init__(
5053

5154
if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
5255
metadata = temp_env.get_attr("metadata")[0]
53-
else:
56+
else: # pragma: no cover # assume gym interface
5457
metadata = temp_env.metadata
5558

5659
self.env.metadata = metadata
@@ -67,15 +70,12 @@ def __init__(
6770
self.step_id = 0
6871
self.video_length = video_length
6972

70-
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
71-
self.video_path = os.path.join(self.video_folder, self.video_name)
72-
7373
self.recording = False
7474
self.recorded_frames: list[np.ndarray] = []
7575

7676
try:
7777
import moviepy # noqa: F401
78-
except ImportError as e:
78+
except ImportError as e: # pragma: no cover
7979
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e
8080

8181
def reset(self) -> VecEnvObs:
@@ -85,6 +85,9 @@ def reset(self) -> VecEnvObs:
8585
return obs
8686

8787
def _start_video_recorder(self) -> None:
88+
# Update video name and path
89+
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
90+
self.video_path = os.path.join(self.video_folder, self.video_name)
8891
self._start_recording()
8992
self._capture_frame()
9093

@@ -109,8 +112,6 @@ def _capture_frame(self) -> None:
109112
assert self.recording, "Cannot capture a frame, recording wasn't started."
110113

111114
frame = self.env.render()
112-
if isinstance(frame, list):
113-
frame = frame[-1]
114115

115116
if isinstance(frame, np.ndarray):
116117
self.recorded_frames.append(frame)
@@ -123,12 +124,12 @@ def _capture_frame(self) -> None:
123124
def close(self) -> None:
124125
"""Closes the wrapper then the video recorder."""
125126
VecEnvWrapper.close(self)
126-
if self.recording:
127+
if self.recording: # pragma: no cover
127128
self._stop_recording()
128129

129130
def _start_recording(self) -> None:
130131
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
131-
if self.recording:
132+
if self.recording: # pragma: no cover
132133
self._stop_recording()
133134

134135
self.recording = True
@@ -137,7 +138,7 @@ def _stop_recording(self) -> None:
137138
"""Stop current recording and saves the video."""
138139
assert self.recording, "_stop_recording was called, but no recording was started"
139140

140-
if len(self.recorded_frames) == 0:
141+
if len(self.recorded_frames) == 0: # pragma: no cover
141142
logger.warn("Ignored saving a video as there were zero frames to save.")
142143
else:
143144
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
@@ -150,5 +151,5 @@ def _stop_recording(self) -> None:
150151

151152
def __del__(self) -> None:
152153
"""Warn the user in case last video wasn't saved."""
153-
if len(self.recorded_frames) > 0:
154+
if len(self.recorded_frames) > 0: # pragma: no cover
154155
logger.warn("Unable to save last video! Did you call close()?")

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.5.0a0
1+
2.5.0a1

tests/test_vec_envs.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,17 @@
1111
import pytest
1212
from gymnasium import spaces
1313

14+
from stable_baselines3 import PPO
1415
from stable_baselines3.common.env_util import make_vec_env
1516
from stable_baselines3.common.monitor import Monitor
16-
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
17+
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder
18+
19+
try:
20+
import moviepy
21+
22+
have_moviepy = True
23+
except ImportError:
24+
have_moviepy = False
1725

1826
N_ENVS = 3
1927
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
@@ -624,3 +632,37 @@ def test_render(vec_env_class):
624632
vec_env.render()
625633

626634
vec_env.close()
635+
636+
637+
@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed")
638+
def test_video_recorder(tmp_path):
639+
env_id = "CartPole-v1"
640+
video_folder = str(tmp_path)
641+
642+
vec_env = make_vec_env(env_id, n_envs=1)
643+
644+
# Wrap to check unwrapping works
645+
vec_env = VecNormalize(vec_env)
646+
647+
# Record the video starting at the first step
648+
vec_env = VecVideoRecorder(
649+
vec_env,
650+
video_folder,
651+
record_video_trigger=lambda x: x % 65 == 0,
652+
video_length=10,
653+
name_prefix=f"agent-{env_id}",
654+
)
655+
656+
model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0)
657+
658+
model.learn(total_timesteps=128)
659+
660+
# print all videos in video_folder, should be multiple step 0-100, step 1024-1124
661+
video_files = list(map(str, tmp_path.glob("*.mp4")))
662+
663+
# Clean up
664+
vec_env.close()
665+
666+
assert len(video_files) == 2
667+
assert "agent-CartPole-v1-step-65-to-step-75.mp4" in video_files[0]
668+
assert "agent-CartPole-v1-step-0-to-step-10.mp4" in video_files[1]

0 commit comments

Comments
 (0)