Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
frames.sizes()[1] == 3,
"frame must have 3 channels (R, G, B), got ",
frames.sizes()[1]);
// TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
return frames.contiguous();
}

Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from ._audio_encoder import AudioEncoder # noqa
from ._video_encoder import VideoEncoder # noqa
92 changes: 92 additions & 0 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from pathlib import Path
from typing import Union

import torch
from torch import Tensor

from torchcodec import _core


class VideoEncoder:
    """A video encoder.

    Args:
        frames (``torch.Tensor``): The frames to encode. This must be a 4D
            tensor of shape ``(N, C, H, W)`` where N is the number of frames,
            C is 3 channels (RGB), H is height, and W is width.
            Values must be uint8 in the range ``[0, 255]``.
        frame_rate (int): The frame rate of the **input** ``frames``. Also
            defines the encoded **output** frame rate.

    Raises:
        ValueError: If ``frames`` is not a ``torch.Tensor``, is not 4D, is not
            of dtype ``torch.uint8``, or if ``frame_rate`` is not strictly
            positive.
    """

    def __init__(self, frames: Tensor, *, frame_rate: int):
        torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
        # Only the tensor type, rank, dtype and the frame rate are validated
        # here; the frames are otherwise passed as-is to the core encoding
        # functions when one of the ``to_*`` methods is called.
        if not isinstance(frames, Tensor):
            raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")
        if frames.ndim != 4:
            raise ValueError(f"Expected 4D frames, got {frames.shape = }.")
        if frames.dtype != torch.uint8:
            raise ValueError(f"Expected uint8 frames, got {frames.dtype = }.")
        if frame_rate <= 0:
            raise ValueError(f"{frame_rate = } must be > 0.")

        self._frames = frames
        self._frame_rate = frame_rate

    def to_file(
        self,
        dest: Union[str, Path],
    ) -> None:
        """Encode frames into a file.

        Args:
            dest (str or ``pathlib.Path``): The path to the output file, e.g.
                ``video.mp4``. The extension of the file determines the video
                container format.
        """
        _core.encode_video_to_file(
            frames=self._frames,
            frame_rate=self._frame_rate,
            # The core function receives the destination as a plain string.
            filename=str(dest),
        )

    def to_tensor(
        self,
        format: str,
    ) -> Tensor:
        """Encode frames into raw bytes, as a 1D uint8 Tensor.

        Args:
            format (str): The container format of the encoded frames, e.g.
                "mp4", "mov", "mkv", "avi", "webm", "flv", or "gif".

        Returns:
            Tensor: The raw encoded bytes as a 1D uint8 Tensor.
        """
        return _core.encode_video_to_tensor(
            frames=self._frames,
            frame_rate=self._frame_rate,
            format=format,
        )

    def to_file_like(
        self,
        file_like,
        format: str,
    ) -> None:
        """Encode frames into a file-like object.

        Args:
            file_like: A file-like object that supports ``write()`` and
                ``seek()`` methods, such as io.BytesIO(), an open file in binary
                write mode, etc. Methods must have the following signature:
                ``write(data: bytes) -> int`` and ``seek(offset: int, whence:
                int = 0) -> int``.
            format (str): The container format of the encoded frames, e.g.
                "mp4", "mov", "mkv", "avi", "webm", "flv", or "gif".
        """
        _core.encode_video_to_file_like(
            frames=self._frames,
            frame_rate=self._frame_rate,
            format=format,
            file_like=file_like,
        )
114 changes: 113 additions & 1 deletion test/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from torchcodec.decoders import AudioDecoder

from torchcodec.encoders import AudioEncoder
from torchcodec.encoders import AudioEncoder, VideoEncoder

from .utils import (
assert_tensor_close_on_at_least,
Expand Down Expand Up @@ -564,3 +564,115 @@ def write(self, data):
RuntimeError, match="File like object must implement a seek method"
):
encoder.to_file_like(NoSeekMethod(), format="wav")


class TestVideoEncoder:
    """Tests for the public ``VideoEncoder`` API."""

    def test_bad_constructor_input(self):
        """Invalid constructor arguments are rejected with ``ValueError``.

        These checks run in ``VideoEncoder.__init__``, so they do not depend on
        which encoding method would be called afterwards and are not
        parametrized over the method.
        """
        with pytest.raises(ValueError, match="Expected frames to be a Tensor"):
            VideoEncoder(frames=[[0, 1], [2, 3]], frame_rate=30)

        with pytest.raises(
            ValueError, match="Expected uint8 frames, got frames.dtype = torch.float32"
        ):
            VideoEncoder(
                frames=torch.rand(5, 3, 64, 64),
                frame_rate=30,
            )

        with pytest.raises(
            ValueError, match=r"Expected 4D frames, got frames.shape = torch.Size"
        ):
            VideoEncoder(
                frames=torch.zeros(10),
                frame_rate=30,
            )

        with pytest.raises(ValueError, match=r"frame_rate = 0 must be > 0"):
            VideoEncoder(
                frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
                frame_rate=0,
            )

    @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
    def test_bad_input_parameterized(self, tmp_path, method):
        """Every encoding method rejects frames without 3 channels."""
        if method == "to_file":
            valid_params = dict(dest=str(tmp_path / "output.mp4"))
        elif method == "to_tensor":
            valid_params = dict(format="mp4")
        elif method == "to_file_like":
            valid_params = dict(file_like=io.BytesIO(), format="mp4")
        else:
            raise ValueError(f"Unknown method: {method}")

        # The constructor accepts any 4D uint8 tensor, so the encoder is built
        # outside of ``pytest.raises``: the error must come from the encoding
        # call itself.
        encoder = VideoEncoder(
            frames=torch.zeros((5, 2, 64, 64), dtype=torch.uint8),
            frame_rate=30,
        )
        with pytest.raises(
            RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
        ):
            getattr(encoder, method)(**valid_params)

    def test_bad_input(self, tmp_path):
        """Bad destinations and formats raise ``RuntimeError`` when encoding."""
        encoder = VideoEncoder(
            frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8),
            frame_rate=30,
        )

        with pytest.raises(
            RuntimeError,
            match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
        ):
            encoder.to_file("./file.bad_extension")

        with pytest.raises(
            RuntimeError,
            match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
        ):
            encoder.to_file("./bad/path.mp3")

        with pytest.raises(
            RuntimeError,
            match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
        ):
            encoder.to_tensor(format="bad_format")

    @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
    def test_contiguity(self, method, tmp_path):
        # Ensure that 2 sets of video frames with the same pixel values are encoded
        # in the same way, regardless of their memory layout. Here we encode 2 equal
        # frame tensors, one is contiguous while the other is non-contiguous.

        num_frames, channels, height, width = 5, 3, 64, 64
        contiguous_frames = torch.randint(
            0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8
        ).contiguous()
        assert contiguous_frames.is_contiguous()

        # Permute NCHW to NHWC, then update the memory layout, then permute back
        non_contiguous_frames = (
            contiguous_frames.permute(0, 2, 3, 1).contiguous().permute(0, 3, 1, 2)
        )
        assert non_contiguous_frames.stride() != contiguous_frames.stride()
        assert not non_contiguous_frames.is_contiguous()
        assert non_contiguous_frames.is_contiguous(memory_format=torch.channels_last)

        # Same values, different memory layout.
        torch.testing.assert_close(
            contiguous_frames, non_contiguous_frames, rtol=0, atol=0
        )

        def encode_to_tensor(frames):
            # Normalize the output of each encoding method to a 1D uint8
            # tensor of encoded bytes so the results can be compared directly.
            if method == "to_file":
                dest = str(tmp_path / "output.mp4")
                VideoEncoder(frames, frame_rate=30).to_file(dest=dest)
                with open(dest, "rb") as f:
                    return torch.frombuffer(f.read(), dtype=torch.uint8)
            elif method == "to_tensor":
                return VideoEncoder(frames, frame_rate=30).to_tensor(format="mp4")
            elif method == "to_file_like":
                file_like = io.BytesIO()
                VideoEncoder(frames, frame_rate=30).to_file_like(
                    file_like, format="mp4"
                )
                return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8)
            else:
                raise ValueError(f"Unknown method: {method}")

        encoded_from_contiguous = encode_to_tensor(contiguous_frames)
        encoded_from_non_contiguous = encode_to_tensor(non_contiguous_frames)

        torch.testing.assert_close(
            encoded_from_contiguous, encoded_from_non_contiguous, rtol=0, atol=0
        )
64 changes: 1 addition & 63 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,68 +1152,6 @@ def test_bad_input(self, tmp_path):


class TestVideoEncoderOps:
# TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
# TODO-VideoEncoder: Parametrize test after moving to test_encoders
def test_bad_input(self, tmp_path):
output_file = str(tmp_path / ".mp4")

with pytest.raises(
RuntimeError, match="frames must have uint8 dtype, got float"
):
encode_video_to_file(
frames=torch.rand((10, 3, 60, 60), dtype=torch.float),
frame_rate=10,
filename=output_file,
)

with pytest.raises(
RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3"
):
encode_video_to_file(
frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename=output_file,
)

with pytest.raises(
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
):
encode_video_to_file(
frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename=output_file,
)

with pytest.raises(
RuntimeError,
match=r"Couldn't allocate AVFormatContext. The destination file is ./file.bad_extension, check the desired extension\?",
):
encode_video_to_file(
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename="./file.bad_extension",
)

with pytest.raises(
RuntimeError,
match=r"avio_open failed. The destination file is ./bad/path.mp3, make sure it's a valid path\?",
):
encode_video_to_file(
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
frame_rate=10,
filename="./bad/path.mp3",
)

with pytest.raises(
RuntimeError,
match=r"Couldn't allocate AVFormatContext. Check the desired format\? Got format=bad_format",
):
encode_video_to_tensor(
frames=torch.randint(high=255, size=(10, 3, 60, 60), dtype=torch.uint8),
frame_rate=10,
format="bad_format",
)

def decode(self, source=None) -> torch.Tensor:
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)

Expand Down Expand Up @@ -1406,7 +1344,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
)

def test_to_file_like_custom_file_object(self):
"""Test with a custom file-like object that implements write and seek."""
"""Test to_file_like with a custom file-like object that implements write and seek."""

class CustomFileObject:
def __init__(self):
Expand Down
Loading