Skip to content

Commit 5a1da00

Browse files
Acme Contributorcopybara-github
authored andcommitted
Remove video wrapper to avoid directly calling ffmpeg.
PiperOrigin-RevId: 870846808 Change-Id: Ie8848b988fdaaac97e21e8fd43ac781369c5b011
1 parent 15bd3e8 commit 5a1da00

File tree

1 file changed

+149
-22
lines changed

1 file changed

+149
-22
lines changed

acme/wrappers/video.py

Lines changed: 149 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,64 @@
1818
the `dm_control/tutorial.ipynb` file.
1919
"""
2020

21+
import os.path
22+
import tempfile
2123
from typing import Callable, Optional, Sequence, Tuple, Union
2224

23-
from absl import logging
2425
from acme.utils import paths
2526
from acme.wrappers import base
2627
import dm_env
28+
29+
import matplotlib
30+
matplotlib.use('Agg') # Switch to headless 'Agg' to inhibit figure rendering.
31+
import matplotlib.animation as anim # pylint: disable=g-import-not-at-top
32+
import matplotlib.pyplot as plt
2733
import numpy as np
2834

35+
# Internal imports.
36+
# Make sure you have FFMpeg configured.
2937

3038
def make_animation(
31-
frames: Sequence[np.ndarray],
32-
frame_rate: float,
33-
figsize: Optional[Union[float, Tuple[int, int]]],
34-
):
39+
frames: Sequence[np.ndarray], frame_rate: float,
40+
figsize: Optional[Union[float, Tuple[int, int]]]) -> anim.Animation:
3541
"""Generates a matplotlib animation from a stack of frames."""
36-
logging.warning(
37-
'make_animation is deprecated and currently acts as a no-op in order to '
38-
'avoid using ffmpeg directly. The old behavior can be restored by '
39-
'replacing the direct call to ffmpeg within matplotlib.'
40-
)
41-
del frames
42-
del frame_rate
43-
del figsize
44-
return None
42+
43+
# Set animation characteristics.
44+
if figsize is None:
45+
height, width, _ = frames[0].shape
46+
elif isinstance(figsize, tuple):
47+
height, width = figsize
48+
else:
49+
diagonal = figsize
50+
height, width, _ = frames[0].shape
51+
scale_factor = diagonal / np.sqrt(height**2 + width**2)
52+
width *= scale_factor
53+
height *= scale_factor
54+
55+
dpi = 70
56+
interval = int(round(1e3 / frame_rate)) # Time (in ms) between frames.
57+
58+
# Create and configure the figure.
59+
fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
60+
ax.set_axis_off()
61+
ax.set_aspect('equal')
62+
ax.set_position([0, 0, 1, 1])
63+
64+
# Initialize the first frame.
65+
im = ax.imshow(frames[0])
66+
67+
# Create the function that will modify the frame, creating an animation.
68+
def update(frame):
69+
im.set_data(frame)
70+
return [im]
71+
72+
return anim.FuncAnimation(
73+
fig=fig,
74+
func=update,
75+
frames=frames,
76+
interval=interval,
77+
blit=True,
78+
repeat=False)
4579

4680

4781
class VideoWrapper(base.EnvironmentWrapper):
@@ -66,23 +100,77 @@ def __init__(
66100
figsize: Optional[Union[float, Tuple[int, int]]] = None,
67101
to_html: bool = True,
68102
):
69-
logging.warning(
70-
'VideoWrapper is deprecated and currently acts as a no-op in order to '
71-
'avoid using ffmpeg directly. The old behavior can be restored by '
72-
'replacing the direct call to ffmpeg within matplotlib.'
73-
)
74103
super(VideoWrapper, self).__init__(environment)
104+
self._path = process_path(path, 'videos')
105+
self._filename = filename
106+
self._record_every = record_every
107+
self._frame_rate = frame_rate
108+
self._frames = []
109+
self._counter = 0
110+
self._figsize = figsize
111+
self._to_html = to_html
112+
113+
def _render_frame(self, observation):
114+
"""Renders a frame from the given environment observation."""
115+
return observation
116+
117+
def _write_frames(self):
118+
"""Writes frames to video."""
119+
if self._counter % self._record_every == 0:
120+
animation = make_animation(self._frames, self._frame_rate, self._figsize)
121+
path_without_extension = os.path.join(
122+
self._path, f'{self._filename}_{self._counter:04d}'
123+
)
124+
if self._to_html:
125+
path = path_without_extension + '.html'
126+
video = animation.to_html5_video()
127+
with open(path, 'w') as f:
128+
f.write(video)
129+
else:
130+
path = path_without_extension + '.m4v'
131+
# Animation.save can save only locally. Save first and copy using
132+
# gfile.
133+
with tempfile.TemporaryDirectory() as tmp_dir:
134+
tmp_path = os.path.join(tmp_dir, 'temp.m4v')
135+
animation.save(tmp_path)
136+
with open(path, 'wb') as f:
137+
with open(tmp_path, 'rb') as g:
138+
f.write(g.read())
139+
140+
# Clear the frame buffer whether a video was generated or not.
141+
self._frames = []
142+
143+
def _append_frame(self, observation):
144+
"""Appends a frame to the sequence of frames."""
145+
if self._counter % self._record_every == 0:
146+
self._frames.append(self._render_frame(observation))
75147

76148
def step(self, action) -> dm_env.TimeStep:
77-
return self.environment.step(action)
149+
timestep = self.environment.step(action)
150+
self._append_frame(timestep.observation)
151+
return timestep
78152

79153
def reset(self) -> dm_env.TimeStep:
80-
return self.environment.reset()
154+
# If the frame buffer is nonempty, flush it and record video
155+
if self._frames:
156+
self._write_frames()
157+
self._counter += 1
158+
timestep = self.environment.reset()
159+
self._append_frame(timestep.observation)
160+
return timestep
81161

82162
def make_html_animation(self):
83-
return None
163+
if self._frames:
164+
return make_animation(self._frames, self._frame_rate,
165+
self._figsize).to_html5_video()
166+
else:
167+
raise ValueError('make_html_animation should be called after running a '
168+
'trajectory and before calling reset().')
84169

85170
def close(self):
171+
if self._frames:
172+
self._write_frames()
173+
self._frames = []
86174
self.environment.close()
87175

88176

@@ -127,3 +215,42 @@ def __init__(self,
127215
self._camera_id = camera_id
128216
self._height = height
129217
self._width = width
218+
219+
def _render_frame(self, unused_observation):
220+
del unused_observation
221+
222+
# We've checked above that this attribute should exist. Pytype won't like
223+
# it if we just try and do self.environment.physics, so we use the slightly
224+
# grosser version below.
225+
physics = getattr(self.environment, 'physics')
226+
227+
if self._camera_id is not None:
228+
frame = physics.render(
229+
camera_id=self._camera_id, height=self._height, width=self._width)
230+
else:
231+
# If camera_id is None, we create a minimal canvas that will accommodate
232+
# physics.model.ncam frames, and render all of them on a grid.
233+
num_cameras = physics.model.ncam
234+
num_columns = int(np.ceil(np.sqrt(num_cameras)))
235+
num_rows = int(np.ceil(float(num_cameras)/num_columns))
236+
height = self._height
237+
width = self._width
238+
239+
# Make a black canvas.
240+
frame = np.zeros((num_rows*height, num_columns*width, 3), dtype=np.uint8)
241+
242+
for col in range(num_columns):
243+
for row in range(num_rows):
244+
245+
camera_id = row*num_columns + col
246+
247+
if camera_id >= num_cameras:
248+
break
249+
250+
subframe = physics.render(
251+
camera_id=camera_id, height=height, width=width)
252+
253+
# Place the frame in the appropriate rectangle on the pixel canvas.
254+
frame[row*height:(row+1)*height, col*width:(col+1)*width] = subframe
255+
256+
return frame

0 commit comments

Comments
 (0)