Skip to content

Commit 491ed0f

Browse files
authored
[Refactor] Upgrade to torchcodec for video export (#3540)
1 parent ff77695 commit 491ed0f

File tree

6 files changed

+92
-24
lines changed

6 files changed

+92
-24
lines changed

.github/unittest/linux/scripts/run_all.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
2525
apt-get install -y libfreetype6-dev pkg-config
2626

2727
apt-get install -y libglfw3 libosmesa6 libglew-dev
28-
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg
28+
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg \
29+
libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libswresample-dev
2930

3031
if [ "${CU_VERSION:-}" == cpu ] ; then
3132
apt-get upgrade -y libstdc++6
@@ -129,6 +130,7 @@ uv_pip_install \
129130
wandb \
130131
mlflow \
131132
av \
133+
torchcodec \
132134
coverage \
133135
transformers \
134136
ninja \

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ atari = ["gymnasium[atari]"]
5151
dm_control = ["dm_control"]
5252
replay_buffer = ["torch>=2.7.0"]
5353
gym_continuous = ["gymnasium<1.0", "mujoco"]
54-
rendering = ["moviepy<2.0.0"]
54+
rendering = ["moviepy<2.0.0", "torchcodec>=0.10.0"]
5555
tests = [
5656
"pytest",
5757
"pyyaml",

test/test_loggers.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from tensordict import MemoryMappedTensor
1919

2020
from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
21+
from torchrl.record.loggers.common import _has_torchcodec, _has_tv
2122
from torchrl.record.loggers.csv import CSVLogger
22-
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
23+
from torchrl.record.loggers.mlflow import _has_mlflow, MLFlowLogger
2324
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
2425
from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger
2526
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
@@ -34,6 +35,8 @@
3435
else:
3536
TORCHVISION_VERSION = version.parse("0.0.1")
3637

38+
_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec
39+
3740
if _has_tb:
3841
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
3942

@@ -170,10 +173,7 @@ def test_log_scalar(self, steps, tmpdir):
170173

171174
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
172175
@pytest.mark.parametrize(
173-
"video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else []
174-
)
175-
@pytest.mark.skipif(
176-
TORCHVISION_VERSION < version.parse("0.20.0"), reason="av compatibility bug"
176+
"video_format", ["pt", "memmap"] + (["mp4"] if _has_mp4 else [])
177177
)
178178
def test_log_video(self, steps, video_format, tmpdir):
179179
torch.manual_seed(0)
@@ -217,11 +217,18 @@ def test_log_video(self, steps, video_format, tmpdir):
217217
)
218218
assert torch.equal(video, logged_video), logged_video
219219
elif video_format == "mp4":
220-
import torchvision
220+
if _has_torchcodec:
221+
from torchcodec.decoders import VideoDecoder
221222

222-
logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
223-
:, :1
224-
]
223+
logged_video = (
224+
VideoDecoder(path)
225+
.get_frames_in_range(start=0, stop=128)
226+
.data[:, :1]
227+
)
228+
else:
229+
logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
230+
:, :1
231+
]
225232
logged_video = logged_video.unsqueeze(0)
226233
torch.testing.assert_close(video, logged_video)
227234

@@ -376,7 +383,7 @@ def test_log_scalar(self, steps, mlflow_fixture):
376383
assert metric.value == values[i].item()
377384

378385
@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
379-
@pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
386+
@pytest.mark.skipif(not _has_mp4, reason="no mp4 video backend available")
380387
def test_log_video(self, steps, mlflow_fixture):
381388

382389
logger, client = mlflow_fixture

torchrl/record/loggers/common.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,76 @@
55
from __future__ import annotations
66

77
import abc
8+
import importlib.util
89
from collections.abc import Sequence
910
from typing import Any
1011

1112
import torch
1213
from tensordict import TensorDictBase
1314
from torch import Tensor
1415

16+
from torchrl._utils import implement_for
17+
18+
_has_tv = importlib.util.find_spec("torchvision") is not None
19+
try:
20+
from torchcodec.encoders import VideoEncoder # noqa: F401
21+
22+
_has_torchcodec = True
23+
del VideoEncoder
24+
except Exception:
25+
_has_torchcodec = False
26+
1527

1628
__all__ = ["Logger"]
1729

1830

31+
@implement_for("torchvision", None, "0.22")
32+
def _write_video(filename, video_array, **kwargs):
33+
if not _has_tv:
34+
raise ImportError(
35+
"Writing mp4 videos with torchvision < 0.22 requires torchvision to be installed. "
36+
"Install it with: pip install torchvision"
37+
)
38+
import torchvision
39+
40+
torchvision.io.write_video(filename, video_array, **kwargs)
41+
42+
43+
@implement_for("torchvision", "0.22")
44+
def _write_video(filename, video_array, **kwargs): # noqa: F811
45+
if not _has_torchcodec:
46+
raise ImportError(
47+
"Writing mp4 videos with torchvision >= 0.22 requires torchcodec >= 0.10.0, "
48+
"since torchvision.io.write_video was deprecated in 0.22 and removed in 0.24. "
49+
"Install it with: pip install 'torchcodec>=0.10.0'"
50+
)
51+
from torchcodec.encoders import VideoEncoder
52+
53+
fps = kwargs.pop("fps", 30)
54+
video_codec = kwargs.pop("video_codec", None)
55+
options = dict(kwargs.pop("options", None) or {})
56+
crf = options.pop("crf", None)
57+
preset = options.pop("preset", None)
58+
pixel_format = options.pop("pixel_format", None)
59+
60+
# VideoEncoder expects (N, C, H, W); callers pass (T, H, W, C)
61+
video_array = video_array.permute(0, 3, 1, 2).contiguous()
62+
63+
to_file_kwargs = {}
64+
if video_codec is not None:
65+
to_file_kwargs["codec"] = video_codec
66+
if crf is not None:
67+
to_file_kwargs["crf"] = float(crf)
68+
if preset is not None:
69+
to_file_kwargs["preset"] = preset
70+
if pixel_format is not None:
71+
to_file_kwargs["pixel_format"] = pixel_format
72+
if options:
73+
to_file_kwargs["extra_options"] = options
74+
75+
VideoEncoder(frames=video_array, frame_rate=fps).to_file(filename, **to_file_kwargs)
76+
77+
1978
def _make_metrics_safe(
2079
metrics: dict[str, Any] | TensorDictBase,
2180
*,

torchrl/record/loggers/csv.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tensordict import MemoryMappedTensor
1515
from torch import Tensor
1616

17-
from .common import Logger
17+
from .common import _write_video, Logger
1818

1919

2020
class CSVExperiment:
@@ -57,9 +57,12 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
5757
- `"pt"`: uses :func:`~torch.save` to save the video tensor);
5858
- `"memmap"`: saved the file as memory-mapped array (reading this file will require
5959
the dtype and shape to be known at read time);
60-
- `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video`
61-
API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``.
62-
These include ``preset``, ``crf`` and others.
60+
- `"mp4"`: saves the file as an `.mp4` file. For torchvision < 0.22, this uses
61+
:func:`~torchvision.io.write_video`; for torchvision >= 0.22, this uses
62+
:class:`~torchcodec.encoders.VideoEncoder` since ``write_video`` was deprecated and
63+
later removed. Any ``kwargs`` passed to ``add_video`` will be transmitted to the
64+
underlying writer. These include ``video_codec``, ``options``
65+
(a dict, e.g. ``{"crf": "23", "preset": "medium"}``), and others.
6366
See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options.
6467
6568
"""
@@ -87,8 +90,6 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
8790
elif self.video_format == "memmap":
8891
MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
8992
elif self.video_format == "mp4":
90-
import torchvision
91-
9293
if vid_tensor.shape[-3] not in (3, 1):
9394
raise RuntimeError(
9495
"expected the video tensor to be of format [T, C, H, W] but the third channel "
@@ -99,7 +100,7 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
99100
vid_tensor = vid_tensor.permute((0, 2, 3, 1))
100101
vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
101102
kwargs.setdefault("fps", self.video_fps)
102-
torchvision.io.write_video(filepath, vid_tensor, **kwargs)
103+
_write_video(filepath, vid_tensor, **kwargs)
103104
else:
104105
raise ValueError(
105106
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
@@ -188,8 +189,8 @@ def log_video(
188189
step (int, optional): The step at which the video is logged. Defaults to None.
189190
**kwargs: other kwargs passed to the underlying video logger.
190191
191-
.. note:: If the video format is `mp4`, many more arguments can be passed to the :meth:`~torchvision.io.write_video`
192-
function.
192+
.. note:: If the video format is `mp4`, additional arguments (e.g. ``video_codec``, ``options``)
193+
can be passed through to the underlying video writer.
193194
For more information on video logging with :class:`~torchrl.record.loggers.csv.CSVLogger`,
194195
see the :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video` documentation.
195196
"""

torchrl/record/loggers/mlflow.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tensordict import TensorDictBase
1515
from torch import Tensor
1616

17-
from torchrl.record.loggers.common import _make_metrics_safe, Logger
17+
from torchrl.record.loggers.common import _make_metrics_safe, _write_video, Logger
1818

1919
_has_tv = importlib.util.find_spec("torchvision") is not None
2020

@@ -99,7 +99,6 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
9999
supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).
100100
"""
101101
import mlflow
102-
import torchvision
103102

104103
if not _has_tv:
105104
raise ImportError(
@@ -119,7 +118,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
119118
with TemporaryDirectory() as temp_dir:
120119
video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4"
121120
with open(os.path.join(temp_dir, video_name), "wb") as f:
122-
torchvision.io.write_video(filename=f.name, video_array=video, fps=fps)
121+
_write_video(f.name, video, fps=fps)
123122
mlflow.log_artifact(f.name, "videos")
124123

125124
def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821

0 commit comments

Comments
 (0)