Skip to content

Commit 45647a1

Browse files
authored
Add pixel_format to VideoEncoder API (#1027)
1 parent dc87228 commit 45647a1

File tree

7 files changed

+131
-30
lines changed

7 files changed

+131
-30
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include "Encoder.h"
55
#include "torch/types.h"
66

7+
extern "C" {
8+
#include <libavutil/pixdesc.h>
9+
}
10+
711
namespace facebook::torchcodec {
812

913
namespace {
@@ -534,6 +538,36 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
534538
return frames.contiguous();
535539
}
536540

541+
AVPixelFormat validatePixelFormat(
542+
const AVCodec& avCodec,
543+
const std::string& targetPixelFormat) {
544+
AVPixelFormat pixelFormat = av_get_pix_fmt(targetPixelFormat.c_str());
545+
546+
// Validate that the encoder supports this pixel format
547+
const AVPixelFormat* supportedFormats = getSupportedPixelFormats(avCodec);
548+
if (supportedFormats != nullptr) {
549+
for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
550+
if (supportedFormats[i] == pixelFormat) {
551+
return pixelFormat;
552+
}
553+
}
554+
}
555+
556+
std::stringstream errorMsg;
557+
// av_get_pix_fmt failed to find a pix_fmt
558+
if (pixelFormat == AV_PIX_FMT_NONE) {
559+
errorMsg << "Unknown pixel format: " << targetPixelFormat;
560+
} else {
561+
errorMsg << "Specified pixel format " << targetPixelFormat
562+
<< " is not supported by the " << avCodec.name << " encoder.";
563+
}
564+
// Build error message, similar to FFmpeg's error log
565+
errorMsg << "\nSupported pixel formats for " << avCodec.name << ":";
566+
for (int i = 0; supportedFormats[i] != AV_PIX_FMT_NONE; ++i) {
567+
errorMsg << " " << av_get_pix_fmt_name(supportedFormats[i]);
568+
}
569+
TORCH_CHECK(false, errorMsg.str());
570+
}
537571
} // namespace
538572

539573
VideoEncoder::~VideoEncoder() {
@@ -635,15 +669,19 @@ void VideoEncoder::initializeEncoder(
635669
outWidth_ = inWidth_;
636670
outHeight_ = inHeight_;
637671

638-
// TODO-VideoEncoder: Enable other pixel formats
639-
// Let FFmpeg choose best pixel format to minimize loss
640-
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
641-
getSupportedPixelFormats(*avCodec), // List of supported formats
642-
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
643-
0, // No alpha channel
644-
nullptr // Discard conversion loss information
645-
);
646-
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
672+
if (videoStreamOptions.pixelFormat.has_value()) {
673+
outPixelFormat_ =
674+
validatePixelFormat(*avCodec, videoStreamOptions.pixelFormat.value());
675+
} else {
676+
const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec);
677+
// Use first listed pixel format as default (often yuv420p).
678+
// This is similar to FFmpeg's logic:
679+
// https://www.ffmpeg.org/doxygen/4.0/decode_8c_source.html#l01087
680+
// If pixel formats are undefined for some reason, try yuv420p
681+
outPixelFormat_ = (formats && formats[0] != AV_PIX_FMT_NONE)
682+
? formats[0]
683+
: AV_PIX_FMT_YUV420P;
684+
}
647685

648686
// Configure codec parameters
649687
avCodecContext_->codec_id = avCodec->id;

src/torchcodec/_core/StreamOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ struct VideoStreamOptions {
4848
// TODO-VideoEncoder: Consider adding other optional fields here
4949
// (bit rate, gop size, max b frames, preset)
5050
std::optional<int> crf;
51+
52+
// Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
53+
// If not specified, uses codec's default format.
54+
std::optional<std::string> pixelFormat;
5155
};
5256

5357
struct AudioStreamOptions {

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, int? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, int? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, int? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -603,8 +603,10 @@ void encode_video_to_file(
603603
const at::Tensor& frames,
604604
int64_t frame_rate,
605605
std::string_view file_name,
606+
std::optional<std::string> pixel_format = std::nullopt,
606607
std::optional<int64_t> crf = std::nullopt) {
607608
VideoStreamOptions videoStreamOptions;
609+
videoStreamOptions.pixelFormat = pixel_format;
608610
videoStreamOptions.crf = crf;
609611
VideoEncoder(
610612
frames,
@@ -618,9 +620,11 @@ at::Tensor encode_video_to_tensor(
618620
const at::Tensor& frames,
619621
int64_t frame_rate,
620622
std::string_view format,
623+
std::optional<std::string> pixel_format = std::nullopt,
621624
std::optional<int64_t> crf = std::nullopt) {
622625
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
623626
VideoStreamOptions videoStreamOptions;
627+
videoStreamOptions.pixelFormat = pixel_format;
624628
videoStreamOptions.crf = crf;
625629
return VideoEncoder(
626630
frames,
@@ -636,6 +640,7 @@ void _encode_video_to_file_like(
636640
int64_t frame_rate,
637641
std::string_view format,
638642
int64_t file_like_context,
643+
std::optional<std::string> pixel_format = std::nullopt,
639644
std::optional<int64_t> crf = std::nullopt) {
640645
auto fileLikeContext =
641646
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
@@ -644,6 +649,7 @@ void _encode_video_to_file_like(
644649
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
645650

646651
VideoStreamOptions videoStreamOptions;
652+
videoStreamOptions.pixelFormat = pixel_format;
647653
videoStreamOptions.crf = crf;
648654

649655
VideoEncoder encoder(

src/torchcodec/_core/ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def encode_video_to_file_like(
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216216
crf: Optional[int] = None,
217+
pixel_format: Optional[str] = None,
217218
) -> None:
218219
"""Encode video frames to a file-like object.
219220
@@ -223,6 +224,7 @@ def encode_video_to_file_like(
223224
format: Video format (e.g., "mp4", "mov", "mkv")
224225
file_like: File-like object that supports write() and seek() methods
225226
crf: Optional constant rate factor for encoding quality
227+
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
226228
"""
227229
assert _pybind_ops is not None
228230

@@ -231,6 +233,7 @@ def encode_video_to_file_like(
231233
frame_rate,
232234
format,
233235
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
236+
pixel_format,
234237
crf,
235238
)
236239

@@ -319,7 +322,8 @@ def encode_video_to_file_abstract(
319322
frames: torch.Tensor,
320323
frame_rate: int,
321324
filename: str,
322-
crf: Optional[int],
325+
crf: Optional[int] = None,
326+
pixel_format: Optional[str] = None,
323327
) -> None:
324328
return
325329

@@ -329,7 +333,8 @@ def encode_video_to_tensor_abstract(
329333
frames: torch.Tensor,
330334
frame_rate: int,
331335
format: str,
332-
crf: Optional[int],
336+
crf: Optional[int] = None,
337+
pixel_format: Optional[str] = None,
333338
) -> torch.Tensor:
334339
return torch.empty([], dtype=torch.long)
335340

@@ -341,6 +346,7 @@ def _encode_video_to_file_like_abstract(
341346
format: str,
342347
file_like_context: int,
343348
crf: Optional[int] = None,
349+
pixel_format: Optional[str] = None,
344350
) -> None:
345351
return
346352

src/torchcodec/encoders/_video_encoder.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Union
2+
from typing import Optional, Union
33

44
import torch
55
from torch import Tensor
@@ -35,29 +35,38 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
3535
def to_file(
3636
self,
3737
dest: Union[str, Path],
38+
*,
39+
pixel_format: Optional[str] = None,
3840
) -> None:
3941
"""Encode frames into a file.
4042
4143
Args:
4244
dest (str or ``pathlib.Path``): The path to the output file, e.g.
4345
``video.mp4``. The extension of the file determines the video
4446
container format.
47+
pixel_format (str, optional): The pixel format for encoding (e.g.,
48+
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
4549
"""
4650
_core.encode_video_to_file(
4751
frames=self._frames,
4852
frame_rate=self._frame_rate,
4953
filename=str(dest),
54+
pixel_format=pixel_format,
5055
)
5156

5257
def to_tensor(
5358
self,
5459
format: str,
60+
*,
61+
pixel_format: Optional[str] = None,
5562
) -> Tensor:
5663
"""Encode frames into raw bytes, as a 1D uint8 Tensor.
5764
5865
Args:
5966
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
6067
"mkv", "avi", "webm", "flv", or "gif"
68+
pixel_format (str, optional): The pixel format to encode frames into (e.g.,
69+
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
6170
6271
Returns:
6372
Tensor: The raw encoded bytes as 4D uint8 Tensor.
@@ -66,12 +75,15 @@ def to_tensor(
6675
frames=self._frames,
6776
frame_rate=self._frame_rate,
6877
format=format,
78+
pixel_format=pixel_format,
6979
)
7080

7181
def to_file_like(
7282
self,
7383
file_like,
7484
format: str,
85+
*,
86+
pixel_format: Optional[str] = None,
7587
) -> None:
7688
"""Encode frames into a file-like object.
7789
@@ -83,10 +95,13 @@ def to_file_like(
8395
int = 0) -> int``.
8496
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
8597
"mkv", "avi", "webm", "flv", or "gif".
98+
pixel_format (str, optional): The pixel format for encoding (e.g.,
99+
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
86100
"""
87101
_core.encode_video_to_file_like(
88102
frames=self._frames,
89103
frame_rate=self._frame_rate,
90104
format=format,
91105
file_like=file_like,
106+
pixel_format=pixel_format,
92107
)

test/test_encoders.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,30 @@ def test_bad_input(self, tmp_path):
629629
):
630630
encoder.to_tensor(format="bad_format")
631631

632+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
633+
def test_pixel_format_errors(self, method, tmp_path):
634+
frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8)
635+
encoder = VideoEncoder(frames, frame_rate=30)
636+
637+
if method == "to_file":
638+
valid_params = dict(dest=str(tmp_path / "output.mp4"))
639+
elif method == "to_tensor":
640+
valid_params = dict(format="mp4")
641+
elif method == "to_file_like":
642+
valid_params = dict(file_like=io.BytesIO(), format="mp4")
643+
644+
with pytest.raises(
645+
RuntimeError,
646+
match=r"Unknown pixel format: invalid_pix_fmt[\s\S]*Supported pixel formats.*yuv420p",
647+
):
648+
getattr(encoder, method)(**valid_params, pixel_format="invalid_pix_fmt")
649+
650+
with pytest.raises(
651+
RuntimeError,
652+
match=r"Specified pixel format rgb24 is not supported[\s\S]*Supported pixel formats.*yuv420p",
653+
):
654+
getattr(encoder, method)(**valid_params, pixel_format="rgb24")
655+
632656
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
633657
def test_contiguity(self, method, tmp_path):
634658
# Ensure that 2 sets of video frames with the same pixel values are encoded

0 commit comments

Comments
 (0)