Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
apt-get install -y libfreetype6-dev pkg-config

apt-get install -y libglfw3 libosmesa6 libglew-dev
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg \
libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libswresample-dev

if [ "${CU_VERSION:-}" == cpu ] ; then
apt-get upgrade -y libstdc++6
Expand Down Expand Up @@ -129,6 +130,7 @@ uv_pip_install \
wandb \
mlflow \
av \
torchcodec \
coverage \
transformers \
ninja \
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ atari = ["gymnasium[atari]"]
dm_control = ["dm_control"]
replay_buffer = ["torch>=2.7.0"]
gym_continuous = ["gymnasium<1.0", "mujoco"]
rendering = ["moviepy<2.0.0"]
rendering = ["moviepy<2.0.0", "torchcodec>=0.10.0"]
tests = [
"pytest",
"pyyaml",
Expand Down
27 changes: 17 additions & 10 deletions test/test_loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from tensordict import MemoryMappedTensor

from torchrl.envs import check_env_specs, GymEnv, ParallelEnv
from torchrl.record.loggers.common import _has_torchcodec, _has_tv
from torchrl.record.loggers.csv import CSVLogger
from torchrl.record.loggers.mlflow import _has_mlflow, _has_tv, MLFlowLogger
from torchrl.record.loggers.mlflow import _has_mlflow, MLFlowLogger
from torchrl.record.loggers.tensorboard import _has_tb, TensorboardLogger
from torchrl.record.loggers.trackio import _has_trackio, TrackioLogger
from torchrl.record.loggers.wandb import _has_wandb, WandbLogger
Expand All @@ -34,6 +35,8 @@
else:
TORCHVISION_VERSION = version.parse("0.0.1")

_has_mp4 = (_has_tv and TORCHVISION_VERSION < version.parse("0.22")) or _has_torchcodec

if _has_tb:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

Expand Down Expand Up @@ -170,10 +173,7 @@ def test_log_scalar(self, steps, tmpdir):

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
@pytest.mark.parametrize(
"video_format", ["pt", "memmap"] + ["mp4"] if _has_tv else []
)
@pytest.mark.skipif(
TORCHVISION_VERSION < version.parse("0.20.0"), reason="av compatibility bug"
"video_format", ["pt", "memmap"] + (["mp4"] if _has_mp4 else [])
)
def test_log_video(self, steps, video_format, tmpdir):
torch.manual_seed(0)
Expand Down Expand Up @@ -217,11 +217,18 @@ def test_log_video(self, steps, video_format, tmpdir):
)
assert torch.equal(video, logged_video), logged_video
elif video_format == "mp4":
import torchvision
if _has_torchcodec:
from torchcodec.decoders import VideoDecoder

logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
:, :1
]
logged_video = (
VideoDecoder(path)
.get_frames_in_range(start=0, stop=128)
.data[:, :1]
)
else:
logged_video = torchvision.io.read_video(path, output_format="TCHW")[0][
:, :1
]
logged_video = logged_video.unsqueeze(0)
torch.testing.assert_close(video, logged_video)

Expand Down Expand Up @@ -376,7 +383,7 @@ def test_log_scalar(self, steps, mlflow_fixture):
assert metric.value == values[i].item()

@pytest.mark.parametrize("steps", [None, [1, 10, 11]])
@pytest.mark.skipif(not _has_tv, reason="torchvision not installed")
@pytest.mark.skipif(not _has_mp4, reason="no mp4 video backend available")
def test_log_video(self, steps, mlflow_fixture):

logger, client = mlflow_fixture
Expand Down
59 changes: 59 additions & 0 deletions torchrl/record/loggers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,76 @@
from __future__ import annotations

import abc
import importlib.util
from collections.abc import Sequence
from typing import Any

import torch
from tensordict import TensorDictBase
from torch import Tensor

from torchrl._utils import implement_for

# Optional video backends. torchvision is only probed (not imported) here.
_has_tv = importlib.util.find_spec("torchvision") is not None

# torchcodec is probed by actually importing the encoder class that the video
# writer relies on, so the flag reflects whether that import really works.
# NOTE(review): the broad ``except Exception`` presumably covers import-time
# failures other than ImportError (e.g. missing FFmpeg shared libraries) -- confirm.
try:
    from torchcodec.encoders import VideoEncoder  # noqa: F401
except Exception:
    _has_torchcodec = False
else:
    _has_torchcodec = True
    # The class is re-imported lazily where it is used; do not keep it as a
    # module-level name.
    del VideoEncoder


__all__ = ["Logger"]


@implement_for("torchvision", None, "0.22")
def _write_video(filename, video_array, **kwargs):
    """Write a video file by delegating to ``torchvision.io.write_video``.

    This variant is registered through ``implement_for`` with the torchvision
    bounds ``(None, "0.22")``; per its error message it targets
    torchvision < 0.22.

    Args:
        filename: destination of the video file, forwarded as-is.
        video_array: the frames to write, forwarded as-is.
        **kwargs: forwarded verbatim to ``torchvision.io.write_video``.

    Raises:
        ImportError: if torchvision is not installed.
    """
    if _has_tv:
        # Imported lazily so this module can be imported without torchvision.
        import torchvision

        torchvision.io.write_video(filename, video_array, **kwargs)
        return
    raise ImportError(
        "Writing mp4 videos with torchvision < 0.22 requires torchvision to be installed. "
        "Install it with: pip install torchvision"
    )


@implement_for("torchvision", "0.22")
def _write_video(filename, video_array, **kwargs):  # noqa: F811
    """Write a video file with torchcodec's ``VideoEncoder``.

    This variant is registered through ``implement_for`` with a torchvision
    lower bound of ``"0.22"``. It accepts the same call shape as the
    torchvision-based variant and translates the arguments to torchcodec.

    Args:
        filename: destination of the video file, passed to
            ``VideoEncoder.to_file``.
        video_array: frames laid out as ``(T, H, W, C)``. Any input accepted
            by :func:`torch.as_tensor` is allowed; it is converted to a
            ``uint8`` tensor, mirroring what ``torchvision.io.write_video``
            did with its input.
        **kwargs: recognised keys are

            - ``fps``: frame rate, defaults to ``30``;
            - ``video_codec``: forwarded as ``codec`` when not ``None``;
            - ``options``: a dict. Its ``crf`` (converted to ``float``),
              ``preset`` and ``pixel_format`` entries are forwarded as
              keyword arguments of the same name; any remaining entries are
              forwarded as ``extra_options``.

            Any other keyword argument is dropped without being read.

    Raises:
        ImportError: if torchcodec's ``VideoEncoder`` could not be imported.
    """
    if not _has_torchcodec:
        raise ImportError(
            "Writing mp4 videos with torchvision >= 0.22 requires torchcodec >= 0.10.0, "
            "since torchvision.io.write_video was deprecated in 0.22 and removed in 0.24. "
            "Install it with: pip install 'torchcodec>=0.10.0'"
        )
    from torchcodec.encoders import VideoEncoder

    fps = kwargs.pop("fps", 30)
    video_codec = kwargs.pop("video_codec", None)
    # Copy so that popping the known keys never mutates the caller's dict.
    options = dict(kwargs.pop("options", None) or {})
    crf = options.pop("crf", None)
    preset = options.pop("preset", None)
    pixel_format = options.pop("pixel_format", None)

    # Keep parity with torchvision.io.write_video, which coerced its input to
    # a uint8 tensor: this accepts array-likes and non-uint8 tensors instead
    # of failing on ``.permute`` or inside the encoder.
    frames = torch.as_tensor(video_array, dtype=torch.uint8)
    # VideoEncoder expects (N, C, H, W); callers pass (T, H, W, C)
    frames = frames.permute(0, 3, 1, 2).contiguous()

    to_file_kwargs = {}
    if video_codec is not None:
        to_file_kwargs["codec"] = video_codec
    if crf is not None:
        # torchvision-style options carry strings (e.g. "23"); cast to a number.
        to_file_kwargs["crf"] = float(crf)
    if preset is not None:
        to_file_kwargs["preset"] = preset
    if pixel_format is not None:
        to_file_kwargs["pixel_format"] = pixel_format
    if options:
        # Whatever is left has no dedicated keyword and is passed through.
        to_file_kwargs["extra_options"] = options

    VideoEncoder(frames=frames, frame_rate=fps).to_file(filename, **to_file_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write_video from torchvision was a self-contained helper that relied on pyav. If you want to avoid having to maintain multiple versions of your own _write_video with different sets of dependencies for each, you could consider just vendoring write_video inside of torchRL.

You'd have a single optional dependency on pyav instead of having an optional dep on both TV and TC, and you wouldn't have to worry about TV's versions either

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write_video from torchvision was a self-contained helper that relied on pyav. If you want to avoid having to maintain multiple versions of your own _write_video with different sets of dependencies for each, you could consider just vendoring write_video inside of torchRL.

Thanks!

We still use tv's transforms so tv dep is not a big worry. But the point of removing an intermediary makes total sense!



def _make_metrics_safe(
metrics: dict[str, Any] | TensorDictBase,
*,
Expand Down
19 changes: 10 additions & 9 deletions torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tensordict import MemoryMappedTensor
from torch import Tensor

from .common import Logger
from .common import _write_video, Logger


class CSVExperiment:
Expand Down Expand Up @@ -57,9 +57,12 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
- `"pt"`: uses :func:`~torch.save` to save the video tensor);
- `"memmap"`: saved the file as memory-mapped array (reading this file will require
the dtype and shape to be known at read time);
- `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video`
API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``.
These include ``preset``, ``crf`` and others.
- `"mp4"`: saves the file as an `.mp4` file. For torchvision < 0.22, this uses
:func:`~torchvision.io.write_video`; for torchvision >= 0.22, this uses
:class:`~torchcodec.encoders.VideoEncoder` since ``write_video`` was deprecated and
later removed. Any ``kwargs`` passed to ``add_video`` will be transmitted to the
underlying writer. These include ``video_codec``, ``options``
(a dict, e.g. ``{"crf": "23", "preset": "medium"}``), and others.
See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options.

"""
Expand Down Expand Up @@ -87,8 +90,6 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
elif self.video_format == "memmap":
MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
elif self.video_format == "mp4":
import torchvision

if vid_tensor.shape[-3] not in (3, 1):
raise RuntimeError(
"expected the video tensor to be of format [T, C, H, W] but the third channel "
Expand All @@ -99,7 +100,7 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
vid_tensor = vid_tensor.permute((0, 2, 3, 1))
vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
kwargs.setdefault("fps", self.video_fps)
torchvision.io.write_video(filepath, vid_tensor, **kwargs)
_write_video(filepath, vid_tensor, **kwargs)
else:
raise ValueError(
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
Expand Down Expand Up @@ -188,8 +189,8 @@ def log_video(
step (int, optional): The step at which the video is logged. Defaults to None.
**kwargs: other kwargs passed to the underlying video logger.

.. note:: If the video format is `mp4`, many more arguments can be passed to the :meth:`~torchvision.io.write_video`
function.
.. note:: If the video format is `mp4`, additional arguments (e.g. ``video_codec``, ``options``)
can be passed through to the underlying video writer.
For more information on video logging with :class:`~torchrl.record.loggers.csv.CSVLogger`,
see the :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video` documentation.
"""
Expand Down
5 changes: 2 additions & 3 deletions torchrl/record/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tensordict import TensorDictBase
from torch import Tensor

from torchrl.record.loggers.common import _make_metrics_safe, Logger
from torchrl.record.loggers.common import _make_metrics_safe, _write_video, Logger

_has_tv = importlib.util.find_spec("torchvision") is not None

Expand Down Expand Up @@ -99,7 +99,6 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).
"""
import mlflow
import torchvision

if not _has_tv:
raise ImportError(
Expand All @@ -119,7 +118,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
with TemporaryDirectory() as temp_dir:
video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4"
with open(os.path.join(temp_dir, video_name), "wb") as f:
torchvision.io.write_video(filename=f.name, video_array=video, fps=fps)
_write_video(f.name, video, fps=fps)
mlflow.log_artifact(f.name, "videos")

def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821
Expand Down
Loading