Merged

Changes from all commits — 28 commits
dd24dfa
Decoder-native resize public implementation
scotts Oct 27, 2025
3a2df84
Lint
scotts Oct 27, 2025
5344ab4
Merge branch 'main' of github.com:pytorch/torchcodec into transform_api
scotts Nov 6, 2025
98cf81b
Implement decoder native transforms API
scotts Nov 7, 2025
65c4ad7
Correct merge
scotts Nov 7, 2025
f300c70
Actually add new file
scotts Nov 7, 2025
2c3b7f0
Lint
scotts Nov 7, 2025
80e84b5
Better assert
scotts Nov 7, 2025
5ac60d8
Better comment
scotts Nov 7, 2025
531b40f
Top level transforms import
scotts Nov 7, 2025
cc333ac
Add the init file. Sigh.
scotts Nov 7, 2025
238a8ff
Linter now needs torchvision in the environment
scotts Nov 7, 2025
55d362c
Avoid missing import errors
scotts Nov 7, 2025
0d2492e
Better names, better docs
scotts Nov 8, 2025
a2da767
More testing, docstring editing
scotts Nov 10, 2025
2cd3f65
Changes
scotts Nov 11, 2025
4ff0186
Reference docs
scotts Nov 12, 2025
0f9eb62
Better docs
scotts Nov 12, 2025
8081298
Make make params private
scotts Nov 12, 2025
39ed9ac
Links to TorchVision.
scotts Nov 12, 2025
6e6815c
Rename conversion function
scotts Nov 12, 2025
363e688
Add no-torchvision job
scotts Nov 12, 2025
463674d
On second thought, let's not
scotts Nov 12, 2025
c20914c
Lists are not covariant?
scotts Nov 12, 2025
254641a
Just use an explicit type
scotts Nov 12, 2025
9b4186a
Pull tv2 inspection logic into decoder transform
scotts Nov 13, 2025
105c77f
Update conversion arg comment
scotts Nov 13, 2025
70b5976
Better importing, better docs
scotts Nov 13, 2025
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
@@ -62,7 +62,7 @@ jobs:
run: python -m pip install --upgrade pip
- name: Install dependencies and FFmpeg
run: |
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge
ffmpeg -version
- name: Build and install torchcodec
17 changes: 17 additions & 0 deletions docs/source/api_ref_transforms.rst
@@ -0,0 +1,17 @@
.. _transforms:

=====================
torchcodec.transforms
=====================

.. currentmodule:: torchcodec.transforms

For a tutorial, see: TODO_DECODER_TRANSFORMS_TUTORIAL.

.. autosummary::
:toctree: generated/
:nosignatures:
:template: dataclass.rst

DecoderTransform
Resize
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -209,6 +209,7 @@ def __call__(self, filename):
intersphinx_mapping = {
"python": ("https://docs.python.org/3/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"torchvision": ("https://docs.pytorch.org/vision/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -125,3 +125,4 @@ Encoding
api_ref_decoders
api_ref_encoders
api_ref_samplers
api_ref_transforms
1 change: 1 addition & 0 deletions mypy.ini
@@ -4,3 +4,4 @@ files = src/torchcodec
show_error_codes = True
pretty = True
allow_redefinition = True
follow_untyped_imports = True
2 changes: 1 addition & 1 deletion src/torchcodec/__init__.py
@@ -7,7 +7,7 @@
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
# but that results in circular import.
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
from . import decoders, encoders, samplers # noqa
from . import decoders, encoders, samplers, transforms # noqa

try:
# Note that version.py is generated during install.
86 changes: 84 additions & 2 deletions src/torchcodec/decoders/_video_decoder.py
@@ -8,17 +8,18 @@
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
from typing import List, Literal, Optional, Sequence, Tuple, Union

import torch
from torch import device as torch_device, Tensor
from torch import device as torch_device, nn, Tensor

from torchcodec import _core as core, Frame, FrameBatch
from torchcodec.decoders._decoder_utils import (
_get_cuda_backend,
create_decoder,
ERROR_REPORTING_INSTRUCTIONS,
)
from torchcodec.transforms import DecoderTransform, Resize


class VideoDecoder:
@@ -66,6 +67,11 @@ class VideoDecoder:
probably is. Default: "exact".
Read more about this parameter in:
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
transforms (sequence of transform objects, optional): Sequence of transforms to be
applied to the decoded frames by the decoder itself, in order. Accepts both
:class:`~torchcodec.transforms.DecoderTransform` and
:class:`~torchvision.transforms.v2.Transform`
objects. Read more about this parameter in: TODO_DECODER_TRANSFORMS_TUTORIAL.
custom_frame_mappings (str, bytes, or file-like object, optional):
Mapping of frames to their metadata, typically generated via ffprobe.
This enables accurate frame seeking without requiring a full video scan.
@@ -104,6 +110,7 @@ def __init__(
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]] = None,
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
] = None,
@@ -148,13 +155,16 @@ def __init__(

device_variant = _get_cuda_backend()

transform_specs = _make_transform_specs(transforms)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
device_variant=device_variant,
transform_specs=transform_specs,
custom_frame_mappings=custom_frame_mappings_data,
)

@@ -432,6 +442,78 @@ def _get_and_validate_stream_metadata(
)


def _convert_to_decoder_transforms(
transforms: Sequence[Union[DecoderTransform, nn.Module]],
) -> List[DecoderTransform]:
"""Convert a sequence of transforms that may contain TorchVision transform
objects into a list of only TorchCodec transform objects.

Args:
transforms: Sequence of transform objects. The objects can be one of two
types:
1. torchcodec.transforms.DecoderTransform
2. torchvision.transforms.v2.Transform, but our type annotation
only mentions its base, nn.Module. We don't want to take a
hard dependency on TorchVision.

Returns:
List of DecoderTransform objects.
"""
try:
from torchvision.transforms import v2

tv_available = True
except ImportError:
tv_available = False

converted_transforms: list[DecoderTransform] = []
for transform in transforms:
if not isinstance(transform, DecoderTransform):
if not tv_available:
raise ValueError(
f"The supplied transform, {transform}, is not a TorchCodec "
" DecoderTransform. TorchCodec also accept TorchVision "
"v2 transforms, but TorchVision is not installed."
)
elif isinstance(transform, v2.Resize):
converted_transforms.append(Resize._from_torchvision(transform))
else:
raise ValueError(
f"Unsupported transform: {transform}. Transforms must be "
"either a TorchCodec DecoderTransform or a TorchVision "
"v2 transform."
)
else:
converted_transforms.append(transform)

return converted_transforms


def _make_transform_specs(
transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
) -> str:
"""Given a sequence of transforms, turn those into the specification string
the core API expects.

Args:
transforms: Optional sequence of transform objects. The objects can be
one of two types:
1. torchcodec.transforms.DecoderTransform
2. torchvision.transforms.v2.Transform, but our type annotation
only mentions its base, nn.Module. We don't want to take a
hard dependency on TorchVision.

Returns:
String of transforms in the format the core API expects: transform
specifications separated by semicolons.
"""
if transforms is None:
return ""

transforms = _convert_to_decoder_transforms(transforms)
return ";".join([t._make_transform_spec() for t in transforms])


def _read_custom_frame_mappings(
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
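To make the new `transforms` parameter concrete, here is a minimal usage sketch based on the diff above. The file name is a placeholder; both transform flavors shown are accepted by the new code path:

```python
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import Resize

# Decoder-native resize: frames come back from the decoder already resized
# to (height, width).
decoder = VideoDecoder("video.mp4", transforms=[Resize(size=(224, 224))])

# Internally, _make_transform_specs() serializes the list into the spec
# string the core API expects, here: "resize, 224, 224".

# TorchVision v2 transforms are also accepted; Resize._from_torchvision()
# converts this one (default bilinear interpolation, antialias on) into the
# equivalent TorchCodec Resize.
from torchvision.transforms import v2

decoder_tv = VideoDecoder("video.mp4", transforms=[v2.Resize(size=(224, 224))])
```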
7 changes: 7 additions & 0 deletions src/torchcodec/transforms/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._decoder_transforms import DecoderTransform, Resize # noqa
93 changes: 93 additions & 0 deletions src/torchcodec/transforms/_decoder_transforms.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from types import ModuleType
from typing import Sequence

from torch import nn


@dataclass
class DecoderTransform(ABC):
"""Base class for all decoder transforms.

A *decoder transform* is a transform that is applied by the decoder before
returning the decoded frame. Applying decoder transforms to frames
should be both faster and more memory efficient than receiving normally
decoded frames and applying the same kind of transform.

Most ``DecoderTransform`` objects have a complementary transform in TorchVision,
specifically in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_. For such transforms, we
ensure that:

1. The names are the same.
2. Default behaviors are the same.
3. The parameters for the ``DecoderTransform`` object are a subset of the
TorchVision :class:`~torchvision.transforms.v2.Transform` object.
4. Parameters with the same name control the same behavior and accept a
subset of the same types.
5. The frames returned by a decoder transform and the complementary
TorchVision transform are close enough that a model should not be able
to tell the difference.
"""

@abstractmethod
def _make_transform_spec(self) -> str:
pass


def import_torchvision_transforms_v2() -> ModuleType:
try:
from torchvision.transforms import v2
except ImportError as e:
raise RuntimeError(
"Cannot import TorchVision; this should never happen, please report a bug."
) from e
return v2


@dataclass
class Resize(DecoderTransform):
"""Resize the decoded frame to a given size.

Complementary TorchVision transform: :class:`~torchvision.transforms.v2.Resize`.
Interpolation is always bilinear. Anti-aliasing is always on.

Args:
size (sequence of int): Desired output size. Must be a sequence of
the form (height, width).
"""

size: Sequence[int]

def _make_transform_spec(self) -> str:
assert len(self.size) == 2
return f"resize, {self.size[0]}, {self.size[1]}"

Contributor Author:
Note this class method below is new. Because I'm trying to exhaustively catch all of the v2.Resize options we don't support, the code for turning a v2.Resize into a torchcodec.transforms.Resize got more involved. Extrapolated across more transforms, this kind of logic would end up dominating the code in _video_decoder.py. By making this a private class method, we can put all of the logic for which v2.Resize options we support, and for how to turn a v2.Resize into a torchcodec.transforms.Resize, in one place.

Also, to state it explicitly, _from_torchvision() and _make_params() are private methods so they're not publicly documented. Users shouldn't need to know about them.

@classmethod
def _from_torchvision(cls, resize_tv: nn.Module):
v2 = import_torchvision_transforms_v2()

assert isinstance(resize_tv, v2.Resize)

if resize_tv.interpolation is not v2.InterpolationMode.BILINEAR:
raise ValueError(
"TorchVision Resize transform must use bilinear interpolation."
)
if resize_tv.antialias is False:
raise ValueError(
"TorchVision Resize transform must have antialias enabled."
)
if resize_tv.size is None:
raise ValueError("TorchVision Resize transform must have a size specified.")
if len(resize_tv.size) != 2:
raise ValueError(
"TorchVision Resize transform must have a (height, width) "
f"pair for the size, got {resize_tv.size}."
)
return cls(size=resize_tv.size)
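As a quick illustration of the conversion logic above, a minimal sketch (assuming TorchVision is installed; the sizes are arbitrary):

```python
from torchvision.transforms import v2

from torchcodec.transforms import Resize

# Defaults line up: v2.Resize uses bilinear interpolation with antialias on,
# which is exactly what the decoder-native Resize always does.
tc_resize = Resize._from_torchvision(v2.Resize(size=(128, 160)))
print(tc_resize._make_transform_spec())  # resize, 128, 160

# Options the decoder cannot honor are rejected rather than silently ignored.
try:
    Resize._from_torchvision(
        v2.Resize(size=(128, 160), interpolation=v2.InterpolationMode.NEAREST)
    )
except ValueError as e:
    print(e)  # TorchVision Resize transform must use bilinear interpolation.
```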