Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ atari = ["gymnasium[atari]"]
dm_control = ["dm_control"]
replay_buffer = ["torch>=2.7.0"]
gym_continuous = ["gymnasium<1.0", "mujoco"]
rendering = ["moviepy<2.0.0"]
rendering = ["moviepy<2.0.0", "torchcodec>=0.10.0"]
tests = [
"pytest",
"pyyaml",
Expand Down
53 changes: 53 additions & 0 deletions torchrl/record/loggers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,70 @@
from __future__ import annotations

import abc
import importlib.util
from collections.abc import Sequence
from typing import Any

import torch
from tensordict import TensorDictBase
from torch import Tensor

from torchrl._utils import implement_for

_has_tv = importlib.util.find_spec("torchvision") is not None
_has_torchcodec = importlib.util.find_spec("torchcodec") is not None


__all__ = ["Logger"]


@implement_for("torchvision", None, "0.22")
def _write_video(filename, video_array, **kwargs):
if not _has_tv:
raise ImportError(
"Writing mp4 videos with torchvision < 0.22 requires torchvision to be installed. "
"Install it with: pip install torchvision"
)
import torchvision

torchvision.io.write_video(filename, video_array, **kwargs)


@implement_for("torchvision", "0.22")
def _write_video(filename, video_array, **kwargs): # noqa: F811
if not _has_torchcodec:
raise ImportError(
"Writing mp4 videos with torchvision >= 0.22 requires torchcodec >= 0.10.0, "
"since torchvision.io.write_video was deprecated in 0.22 and removed in 0.24. "
"Install it with: pip install 'torchcodec>=0.10.0'"
)
from torchcodec.encoders import VideoEncoder

fps = kwargs.pop("fps", 30)
video_codec = kwargs.pop("video_codec", None)
options = dict(kwargs.pop("options", None) or {})
crf = options.pop("crf", None)
preset = options.pop("preset", None)
pixel_format = options.pop("pixel_format", None)

# VideoEncoder expects (N, C, H, W); callers pass (T, H, W, C)
video_array = video_array.permute(0, 3, 1, 2).contiguous()

to_file_kwargs = {}
if video_codec is not None:
to_file_kwargs["codec"] = video_codec
if crf is not None:
to_file_kwargs["crf"] = float(crf)
if preset is not None:
to_file_kwargs["preset"] = preset
if pixel_format is not None:
to_file_kwargs["pixel_format"] = pixel_format
if options:
to_file_kwargs["extra_options"] = options

VideoEncoder(frames=video_array, frame_rate=fps).to_file(filename, **to_file_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write_video from torchvision was a self-contained helper that relied on pyav. If you want to avoid having to maintain multiple version of your own _write_video with different sets of dependencies for each, you could consider just vendoring write_video inside of torchRL

You'd have a single optional dependency on pyav instead of having an optional dep on both TV and TC, and you wouldn't have to worry about TV's versions either

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write_video from torchvision was a self-contained helper that relied on pyav. If you want to avoid having to maintain multiple version of your own _write_video with different sets of dependencies for each, you could consider just vendoring write_video inside of torchRL

Thanks!

We still use tv's transforms so tv dep is not a big worry. But the point of removing an intermediary makes total sense!



def _make_metrics_safe(
metrics: dict[str, Any] | TensorDictBase,
*,
Expand Down
19 changes: 10 additions & 9 deletions torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tensordict import MemoryMappedTensor
from torch import Tensor

from .common import Logger
from .common import _write_video, Logger


class CSVExperiment:
Expand Down Expand Up @@ -57,9 +57,12 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
- `"pt"`: uses :func:`~torch.save` to save the video tensor);
- `"memmap"`: saved the file as memory-mapped array (reading this file will require
the dtype and shape to be known at read time);
- `"mp4"`: saves the file as an `.mp4` file using torchvision :func:`~torchvision.io.write_video`
API. Any ``kwargs`` passed to ``add_video`` will be transmitted to ``write_video``.
These include ``preset``, ``crf`` and others.
- `"mp4"`: saves the file as an `.mp4` file. For torchvision < 0.22, this uses
:func:`~torchvision.io.write_video`; for torchvision >= 0.22, this uses
:class:`~torchcodec.encoders.VideoEncoder` since ``write_video`` was deprecated and
later removed. Any ``kwargs`` passed to ``add_video`` will be transmitted to the
underlying writer. These include ``video_codec``, ``options``
(a dict, e.g. ``{"crf": "23", "preset": "medium"}``), and others.
See ffmpeg's doc (https://trac.ffmpeg.org/wiki/Encode/H.264) for some more information of the video format options.

"""
Expand Down Expand Up @@ -87,8 +90,6 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
elif self.video_format == "memmap":
MemoryMappedTensor.from_tensor(vid_tensor, filename=filepath)
elif self.video_format == "mp4":
import torchvision

if vid_tensor.shape[-3] not in (3, 1):
raise RuntimeError(
"expected the video tensor to be of format [T, C, H, W] but the third channel "
Expand All @@ -99,7 +100,7 @@ def add_video(self, tag, vid_tensor, global_step: int | None = None, **kwargs):
vid_tensor = vid_tensor.permute((0, 2, 3, 1))
vid_tensor = vid_tensor.expand(*vid_tensor.shape[:-1], 3)
kwargs.setdefault("fps", self.video_fps)
torchvision.io.write_video(filepath, vid_tensor, **kwargs)
_write_video(filepath, vid_tensor, **kwargs)
else:
raise ValueError(
f"Unknown video format {self.video_format}. Must be one of 'pt', 'memmap' or 'mp4'."
Expand Down Expand Up @@ -188,8 +189,8 @@ def log_video(
step (int, optional): The step at which the video is logged. Defaults to None.
**kwargs: other kwargs passed to the underlying video logger.

.. note:: If the video format is `mp4`, many more arguments can be passed to the :meth:`~torchvision.io.write_video`
function.
.. note:: If the video format is `mp4`, additional arguments (e.g. ``video_codec``, ``options``)
can be passed through to the underlying video writer.
For more information on video logging with :class:`~torchrl.record.loggers.csv.CSVLogger`,
see the :meth:`~torchrl.record.loggers.csv.CSVExperiment.add_video` documentation.
"""
Expand Down
5 changes: 2 additions & 3 deletions torchrl/record/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from tensordict import TensorDictBase
from torch import Tensor

from torchrl.record.loggers.common import _make_metrics_safe, Logger
from torchrl.record.loggers.common import _make_metrics_safe, _write_video, Logger

_has_tv = importlib.util.find_spec("torchvision") is not None

Expand Down Expand Up @@ -99,7 +99,6 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).
"""
import mlflow
import torchvision

if not _has_tv:
raise ImportError(
Expand All @@ -119,7 +118,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
with TemporaryDirectory() as temp_dir:
video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4"
with open(os.path.join(temp_dir, video_name), "wb") as f:
torchvision.io.write_video(filename=f.name, video_array=video, fps=fps)
_write_video(f.name, video, fps=fps)
mlflow.log_artifact(f.name, "videos")

def log_hparams(self, cfg: DictConfig | dict) -> None: # noqa: F821
Expand Down
Loading