1818the `dm_control/tutorial.ipynb` file.
1919"""
2020
21+ import os .path
22+ import tempfile
2123from typing import Callable , Optional , Sequence , Tuple , Union
2224
23- from absl import logging
2425from acme .utils import paths
2526from acme .wrappers import base
2627import dm_env
28+
29+ import matplotlib
30+ matplotlib .use ('Agg' ) # Switch to headless 'Agg' to inhibit figure rendering.
31+ import matplotlib .animation as anim # pylint: disable=g-import-not-at-top
32+ import matplotlib .pyplot as plt
2733import numpy as np
2834
35+ # Internal imports.
36+ # Make sure you have FFMpeg configured.
2937
3038def make_animation (
31- frames : Sequence [np .ndarray ],
32- frame_rate : float ,
33- figsize : Optional [Union [float , Tuple [int , int ]]],
34- ):
39+ frames : Sequence [np .ndarray ], frame_rate : float ,
40+ figsize : Optional [Union [float , Tuple [int , int ]]]) -> anim .Animation :
3541 """Generates a matplotlib animation from a stack of frames."""
36- logging .warning (
37- 'make_animation is deprecated and currently acts as a no-op in order to '
38- 'avoid using ffmpeg directly. The old behavior can be restored by '
39- 'replacing the direct call to ffmpeg within matplotlib.'
40- )
41- del frames
42- del frame_rate
43- del figsize
44- return None
42+
43+ # Set animation characteristics.
44+ if figsize is None :
45+ height , width , _ = frames [0 ].shape
46+ elif isinstance (figsize , tuple ):
47+ height , width = figsize
48+ else :
49+ diagonal = figsize
50+ height , width , _ = frames [0 ].shape
51+ scale_factor = diagonal / np .sqrt (height ** 2 + width ** 2 )
52+ width *= scale_factor
53+ height *= scale_factor
54+
55+ dpi = 70
56+ interval = int (round (1e3 / frame_rate )) # Time (in ms) between frames.
57+
58+ # Create and configure the figure.
59+ fig , ax = plt .subplots (1 , 1 , figsize = (width / dpi , height / dpi ), dpi = dpi )
60+ ax .set_axis_off ()
61+ ax .set_aspect ('equal' )
62+ ax .set_position ([0 , 0 , 1 , 1 ])
63+
64+ # Initialize the first frame.
65+ im = ax .imshow (frames [0 ])
66+
67+ # Create the function that will modify the frame, creating an animation.
68+ def update (frame ):
69+ im .set_data (frame )
70+ return [im ]
71+
72+ return anim .FuncAnimation (
73+ fig = fig ,
74+ func = update ,
75+ frames = frames ,
76+ interval = interval ,
77+ blit = True ,
78+ repeat = False )
4579
4680
4781class VideoWrapper (base .EnvironmentWrapper ):
@@ -66,23 +100,77 @@ def __init__(
66100 figsize : Optional [Union [float , Tuple [int , int ]]] = None ,
67101 to_html : bool = True ,
68102 ):
69- logging .warning (
70- 'VideoWrapper is deprecated and currently acts as a no-op in order to '
71- 'avoid using ffmpeg directly. The old behavior can be restored by '
72- 'replacing the direct call to ffmpeg within matplotlib.'
73- )
74103 super (VideoWrapper , self ).__init__ (environment )
104+ self ._path = process_path (path , 'videos' )
105+ self ._filename = filename
106+ self ._record_every = record_every
107+ self ._frame_rate = frame_rate
108+ self ._frames = []
109+ self ._counter = 0
110+ self ._figsize = figsize
111+ self ._to_html = to_html
112+
113+ def _render_frame (self , observation ):
114+ """Renders a frame from the given environment observation."""
115+ return observation
116+
117+ def _write_frames (self ):
118+ """Writes frames to video."""
119+ if self ._counter % self ._record_every == 0 :
120+ animation = make_animation (self ._frames , self ._frame_rate , self ._figsize )
121+ path_without_extension = os .path .join (
122+ self ._path , f'{ self ._filename } _{ self ._counter :04d} '
123+ )
124+ if self ._to_html :
125+ path = path_without_extension + '.html'
126+ video = animation .to_html5_video ()
127+ with open (path , 'w' ) as f :
128+ f .write (video )
129+ else :
130+ path = path_without_extension + '.m4v'
131+ # Animation.save can save only locally. Save first and copy using
132+ # gfile.
133+ with tempfile .TemporaryDirectory () as tmp_dir :
134+ tmp_path = os .path .join (tmp_dir , 'temp.m4v' )
135+ animation .save (tmp_path )
136+ with open (path , 'wb' ) as f :
137+ with open (tmp_path , 'rb' ) as g :
138+ f .write (g .read ())
139+
140+ # Clear the frame buffer whether a video was generated or not.
141+ self ._frames = []
142+
143+ def _append_frame (self , observation ):
144+ """Appends a frame to the sequence of frames."""
145+ if self ._counter % self ._record_every == 0 :
146+ self ._frames .append (self ._render_frame (observation ))
75147
76148 def step (self , action ) -> dm_env .TimeStep :
77- return self .environment .step (action )
149+ timestep = self .environment .step (action )
150+ self ._append_frame (timestep .observation )
151+ return timestep
78152
79153 def reset (self ) -> dm_env .TimeStep :
80- return self .environment .reset ()
154+ # If the frame buffer is nonempty, flush it and record video
155+ if self ._frames :
156+ self ._write_frames ()
157+ self ._counter += 1
158+ timestep = self .environment .reset ()
159+ self ._append_frame (timestep .observation )
160+ return timestep
81161
82162 def make_html_animation (self ):
83- return None
163+ if self ._frames :
164+ return make_animation (self ._frames , self ._frame_rate ,
165+ self ._figsize ).to_html5_video ()
166+ else :
167+ raise ValueError ('make_html_animation should be called after running a '
168+ 'trajectory and before calling reset().' )
84169
85170 def close (self ):
171+ if self ._frames :
172+ self ._write_frames ()
173+ self ._frames = []
86174 self .environment .close ()
87175
88176
@@ -127,3 +215,42 @@ def __init__(self,
127215 self ._camera_id = camera_id
128216 self ._height = height
129217 self ._width = width
218+
219+ def _render_frame (self , unused_observation ):
220+ del unused_observation
221+
222+ # We've checked above that this attribute should exist. Pytype won't like
223+ # it if we just try and do self.environment.physics, so we use the slightly
224+ # grosser version below.
225+ physics = getattr (self .environment , 'physics' )
226+
227+ if self ._camera_id is not None :
228+ frame = physics .render (
229+ camera_id = self ._camera_id , height = self ._height , width = self ._width )
230+ else :
231+ # If camera_id is None, we create a minimal canvas that will accommodate
232+ # physics.model.ncam frames, and render all of them on a grid.
233+ num_cameras = physics .model .ncam
234+ num_columns = int (np .ceil (np .sqrt (num_cameras )))
235+ num_rows = int (np .ceil (float (num_cameras )/ num_columns ))
236+ height = self ._height
237+ width = self ._width
238+
239+ # Make a black canvas.
240+ frame = np .zeros ((num_rows * height , num_columns * width , 3 ), dtype = np .uint8 )
241+
242+ for col in range (num_columns ):
243+ for row in range (num_rows ):
244+
245+ camera_id = row * num_columns + col
246+
247+ if camera_id >= num_cameras :
248+ break
249+
250+ subframe = physics .render (
251+ camera_id = camera_id , height = height , width = width )
252+
253+ # Place the frame in the appropriate rectangle on the pixel canvas.
254+ frame [row * height :(row + 1 )* height , col * width :(col + 1 )* width ] = subframe
255+
256+ return frame
0 commit comments