1414from tensordict import MemoryMappedTensor
1515from torch import Tensor
1616
17- from .common import Logger
17+ from .common import _write_video , Logger
1818
1919
2020class CSVExperiment :
@@ -57,9 +57,12 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
5757 - `"pt"`: uses :func:`~torch.save` to save the video tensor);
5858 - `"memmap"`: saved the file as memory-mapped array (reading this file will require
5959 the dtype and shape to be known at read time);
60- - `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video`
61- API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``.
62- These include ``preset``, ``crf`` and others.
60+ - `"mp4"`: saves the file as an `.mp4` file. For torchvision < 0.22, this uses
61+ :func:`~torchvision.io.write_video`; for torchvision >= 0.22, this uses
62+ :class:`~torchcodec.encoders.VideoEncoder` since ``write_video`` was deprecated and
63+ later removed. Any ``kwargs`` passed to ``add_video`` will be transmitted to the
64+ underlying writer. These include ``video_codec``, ``options``
65+ (a dict, e.g. ``{"crf": "23", "preset": "medium"}``), and others.
6366 See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options.
6467
6568 """
@@ -87,8 +90,6 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
8790 elif self .video_format == "memmap" :
8891 MemoryMappedTensor .from_tensor (vid_tensor , filename = filepath )
8992 elif self .video_format == "mp4" :
90- import torchvision
91-
9293 if vid_tensor .shape [- 3 ] not in (3 , 1 ):
9394 raise RuntimeError (
9495 "expected the video tensor to be of format [T, C, H, W] but the third channel "
@@ -99,7 +100,7 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
99100 vid_tensor = vid_tensor .permute ((0 , 2 , 3 , 1 ))
100101 vid_tensor = vid_tensor .expand (* vid_tensor .shape [:- 1 ], 3 )
101102 kwargs .setdefault ("fps" , self .video_fps )
102- torchvision . io . write_video (filepath , vid_tensor , ** kwargs )
103+ _write_video (filepath , vid_tensor , ** kwargs )
103104 else :
104105 raise ValueError (
105106 f"Unknown video format { self .video_format } . Must be one of 'pt', 'memmap' or 'mp4'."
@@ -188,8 +189,8 @@ def log_video(
188189 step (int, optional): The step at which the video is logged. Defaults to None.
189190 **kwargs: other kwargs passed to the underlying video logger.
190191
191- .. note:: If the video format is `mp4`, many more arguments can be passed to the :meth:`~torchvision.io.write_video`
192- function .
192+ .. note:: If the video format is `mp4`, additional arguments (e.g. ``video_codec``, ``options``)
193+ can be passed through to the underlying video writer .
193194 For more information on video logging with :class:`~torchrl.record.loggers.csv.CSVLogger`,
194195 see the :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video` documentation.
195196 """
0 commit comments