Skip to content

Commit d75f0eb

Browse files
committed
pix_fmt added
1 parent afd5aba commit d75f0eb

File tree

6 files changed

+76
-27
lines changed

6 files changed

+76
-27
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include "Encoder.h"
55
#include "torch/types.h"
66

7+
extern "C" {
8+
#include <libavutil/pixdesc.h>
9+
}
10+
711
namespace facebook::torchcodec {
812

913
namespace {
@@ -635,15 +639,21 @@ void VideoEncoder::initializeEncoder(
635639
outWidth_ = inWidth_;
636640
outHeight_ = inHeight_;
637641

638-
// TODO-VideoEncoder: Enable other pixel formats
639-
// Let FFmpeg choose best pixel format to minimize loss
640-
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
641-
getSupportedPixelFormats(*avCodec), // List of supported formats
642-
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
643-
0, // No alpha channel
644-
nullptr // Discard conversion loss information
645-
);
646-
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
642+
if (videoStreamOptions.pixelFormat.has_value()) {
643+
outPixelFormat_ =
644+
av_get_pix_fmt(videoStreamOptions.pixelFormat.value().c_str());
645+
TORCH_CHECK(
646+
outPixelFormat_ != AV_PIX_FMT_NONE,
647+
"Unknown pixel format: ",
648+
videoStreamOptions.pixelFormat.value());
649+
} else {
650+
const AVPixelFormat* formats = getSupportedPixelFormats(*avCodec);
651+
// If pixel formats are undefined for some reason, try a broadly supported
652+
// default.
653+
outPixelFormat_ = (formats && formats[0] != AV_PIX_FMT_NONE)
654+
? formats[0]
655+
: AV_PIX_FMT_YUV420P;
656+
}
647657

648658
// Configure codec parameters
649659
avCodecContext_->codec_id = avCodec->id;

src/torchcodec/_core/StreamOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ struct VideoStreamOptions {
4848
// TODO-VideoEncoder: Consider adding other optional fields here
4949
// (bit rate, gop size, max b frames, preset)
5050
std::optional<int> crf;
51+
52+
// Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p")
53+
// If not specified, uses codec's default format.
54+
std::optional<std::string> pixelFormat;
5155
};
5256

5357
struct AudioStreamOptions {

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, int? crf=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, int? crf=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, int? crf=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -603,8 +603,10 @@ void encode_video_to_file(
603603
const at::Tensor& frames,
604604
int64_t frame_rate,
605605
std::string_view file_name,
606+
std::optional<std::string> pixel_format = std::nullopt,
606607
std::optional<int64_t> crf = std::nullopt) {
607608
VideoStreamOptions videoStreamOptions;
609+
videoStreamOptions.pixelFormat = pixel_format;
608610
videoStreamOptions.crf = crf;
609611
VideoEncoder(
610612
frames,
@@ -618,9 +620,11 @@ at::Tensor encode_video_to_tensor(
618620
const at::Tensor& frames,
619621
int64_t frame_rate,
620622
std::string_view format,
623+
std::optional<std::string> pixel_format = std::nullopt,
621624
std::optional<int64_t> crf = std::nullopt) {
622625
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
623626
VideoStreamOptions videoStreamOptions;
627+
videoStreamOptions.pixelFormat = pixel_format;
624628
videoStreamOptions.crf = crf;
625629
return VideoEncoder(
626630
frames,
@@ -636,6 +640,7 @@ void _encode_video_to_file_like(
636640
int64_t frame_rate,
637641
std::string_view format,
638642
int64_t file_like_context,
643+
std::optional<std::string> pixel_format = std::nullopt,
639644
std::optional<int64_t> crf = std::nullopt) {
640645
auto fileLikeContext =
641646
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
@@ -644,6 +649,7 @@ void _encode_video_to_file_like(
644649
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
645650

646651
VideoStreamOptions videoStreamOptions;
652+
videoStreamOptions.pixelFormat = pixel_format;
647653
videoStreamOptions.crf = crf;
648654

649655
VideoEncoder encoder(

src/torchcodec/_core/ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def encode_video_to_file_like(
213213
format: str,
214214
file_like: Union[io.RawIOBase, io.BufferedIOBase],
215215
crf: Optional[int] = None,
216+
pixel_format: Optional[str] = None,
216217
) -> None:
217218
"""Encode video frames to a file-like object.
218219
@@ -222,6 +223,7 @@ def encode_video_to_file_like(
222223
format: Video format (e.g., "mp4", "mov", "mkv")
223224
file_like: File-like object that supports write() and seek() methods
224225
crf: Optional constant rate factor for encoding quality
226+
pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p")
225227
"""
226228
assert _pybind_ops is not None
227229

@@ -230,6 +232,7 @@ def encode_video_to_file_like(
230232
frame_rate,
231233
format,
232234
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
235+
pixel_format,
233236
crf,
234237
)
235238

@@ -319,6 +322,7 @@ def encode_video_to_file_abstract(
319322
frame_rate: int,
320323
filename: str,
321324
crf: Optional[int],
325+
pixel_format: Optional[str],
322326
) -> None:
323327
return
324328

@@ -329,6 +333,7 @@ def encode_video_to_tensor_abstract(
329333
frame_rate: int,
330334
format: str,
331335
crf: Optional[int],
336+
pixel_format: Optional[str],
332337
) -> torch.Tensor:
333338
return torch.empty([], dtype=torch.long)
334339

@@ -340,6 +345,7 @@ def _encode_video_to_file_like_abstract(
340345
format: str,
341346
file_like_context: int,
342347
crf: Optional[int] = None,
348+
pixel_format: Optional[str] = None,
343349
) -> None:
344350
return
345351

src/torchcodec/encoders/_video_encoder.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Union
2+
from typing import Optional, Union
33

44
import torch
55
from torch import Tensor
@@ -35,29 +35,38 @@ def __init__(self, frames: Tensor, *, frame_rate: int):
3535
def to_file(
3636
self,
3737
dest: Union[str, Path],
38+
*,
39+
pixel_format: Optional[str] = None,
3840
) -> None:
3941
"""Encode frames into a file.
4042
4143
Args:
4244
dest (str or ``pathlib.Path``): The path to the output file, e.g.
4345
``video.mp4``. The extension of the file determines the video
4446
container format.
47+
pixel_format (str, optional): The pixel format for encoding (e.g.,
48+
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
4549
"""
4650
_core.encode_video_to_file(
4751
frames=self._frames,
4852
frame_rate=self._frame_rate,
4953
filename=str(dest),
54+
pixel_format=pixel_format,
5055
)
5156

5257
def to_tensor(
5358
self,
5459
format: str,
60+
*,
61+
pixel_format: Optional[str] = None,
5562
) -> Tensor:
5663
"""Encode frames into raw bytes, as a 1D uint8 Tensor.
5764
5865
Args:
5966
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
6067
"mkv", "avi", "webm", "flv", or "gif"
68+
pixel_format (str, optional): The pixel format to encode frames into (e.g.,
69+
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
6170
6271
Returns:
6372
Tensor: The raw encoded bytes as 4D uint8 Tensor.
@@ -66,12 +75,15 @@ def to_tensor(
6675
frames=self._frames,
6776
frame_rate=self._frame_rate,
6877
format=format,
78+
pixel_format=pixel_format,
6979
)
7080

7181
def to_file_like(
7282
self,
7383
file_like,
7484
format: str,
85+
*,
86+
pixel_format: Optional[str] = None,
7587
) -> None:
7688
"""Encode frames into a file-like object.
7789
@@ -83,10 +95,13 @@ def to_file_like(
8395
int = 0) -> int``.
8496
format (str): The container format of the encoded frames, e.g. "mp4", "mov",
8597
"mkv", "avi", "webm", "flv", or "gif".
98+
pixel_format (str, optional): The pixel format for encoding (e.g.,
99+
"yuv420p", "yuv444p"). If not specified, uses codec's default format.
86100
"""
87101
_core.encode_video_to_file_like(
88102
frames=self._frames,
89103
frame_rate=self._frame_rate,
90104
format=format,
91105
file_like=file_like,
106+
pixel_format=pixel_format,
92107
)

test/test_ops.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,21 +1162,14 @@ def decode(self, source=None) -> torch.Tensor:
11621162
def test_video_encoder_round_trip(self, tmp_path, format, method):
11631163
# Test that decode(encode(decode(frames))) == decode(frames)
11641164
ffmpeg_version = get_ffmpeg_major_version()
1165-
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
1166-
# As a result, we skip the round trip test.
1167-
if ffmpeg_version == 6 and format != "webm":
1168-
pytest.skip(
1169-
f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test."
1170-
)
11711165
if format == "webm" and (
11721166
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
11731167
):
11741168
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
11751169
source_frames = self.decode(TEST_SRC_2_720P.path).data
11761170

1177-
params = dict(
1178-
frame_rate=30, crf=0
1179-
) # Frame rate is fixed with num frames decoded
1171+
# Frame rate is fixed with num frames decoded
1172+
params = dict(frame_rate=30, pixel_format="yuv444p", crf=0)
11801173
if method == "to_file":
11811174
encoded_path = str(tmp_path / f"encoder_output.{format}")
11821175
encode_video_to_file(
@@ -1274,16 +1267,18 @@ def test_against_to_file(self, tmp_path, format, method):
12741267
"avi",
12751268
"mkv",
12761269
"flv",
1277-
"gif",
12781270
pytest.param("webm", marks=pytest.mark.slow),
12791271
),
12801272
)
1281-
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
1273+
@pytest.mark.parametrize("pixel_format", ("yuv444p", "yuv420p"))
1274+
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format, pixel_format):
12821275
ffmpeg_version = get_ffmpeg_major_version()
12831276
if format == "webm" and (
12841277
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
12851278
):
12861279
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
1280+
if format in ("avi", "flv") and pixel_format == "yuv444p":
1281+
pytest.skip(f"Default codec for {format} does not support {pixel_format}")
12871282

12881283
source_frames = self.decode(TEST_SRC_2_720P.path).data
12891284

@@ -1303,13 +1298,15 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
13031298
"-f",
13041299
"rawvideo",
13051300
"-pix_fmt",
1306-
"rgb24",
1301+
"rgb24", # Input format
13071302
"-s",
13081303
f"{source_frames.shape[3]}x{source_frames.shape[2]}",
13091304
"-r",
13101305
str(frame_rate),
13111306
"-i",
13121307
temp_raw_path,
1308+
"-pix_fmt",
1309+
pixel_format, # Output format
13131310
"-crf",
13141311
str(crf),
13151312
ffmpeg_encoded_path,
@@ -1322,6 +1319,7 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
13221319
frames=source_frames,
13231320
frame_rate=frame_rate,
13241321
filename=encoder_output_path,
1322+
pixel_format=pixel_format,
13251323
crf=crf,
13261324
)
13271325

@@ -1362,7 +1360,12 @@ def get_encoded_data(self):
13621360
source_frames = self.decode(TEST_SRC_2_720P.path).data
13631361
file_like = CustomFileObject()
13641362
encode_video_to_file_like(
1365-
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
1363+
source_frames,
1364+
frame_rate=30,
1365+
pixel_format="yuv444p",
1366+
crf=0,
1367+
format="mp4",
1368+
file_like=file_like,
13661369
)
13671370
decoded_samples = self.decode(file_like.get_encoded_data())
13681371

@@ -1380,7 +1383,12 @@ def test_to_file_like_real_file(self, tmp_path):
13801383

13811384
with open(file_path, "wb") as file_like:
13821385
encode_video_to_file_like(
1383-
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
1386+
source_frames,
1387+
frame_rate=30,
1388+
pixel_format="yuv444p",
1389+
crf=0,
1390+
format="mp4",
1391+
file_like=file_like,
13841392
)
13851393
decoded_samples = self.decode(str(file_path))
13861394

0 commit comments

Comments
 (0)