Skip to content

Commit e504970

Browse files
committed
update type to float, update tests
1 parent 1ea235a commit e504970

File tree

6 files changed

+71
-54
lines changed

6 files changed

+71
-54
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ VideoEncoder::~VideoEncoder() {
662662

663663
VideoEncoder::VideoEncoder(
664664
const torch::Tensor& frames,
665-
int frameRate,
665+
double frameRate,
666666
std::string_view fileName,
667667
const VideoStreamOptions& videoStreamOptions)
668668
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
@@ -694,7 +694,7 @@ VideoEncoder::VideoEncoder(
694694

695695
VideoEncoder::VideoEncoder(
696696
const torch::Tensor& frames,
697-
int frameRate,
697+
double frameRate,
698698
std::string_view formatName,
699699
std::unique_ptr<AVIOContextHolder> avioContextHolder,
700700
const VideoStreamOptions& videoStreamOptions)
@@ -787,9 +787,10 @@ void VideoEncoder::initializeEncoder(
787787
avCodecContext_->width = outWidth_;
788788
avCodecContext_->height = outHeight_;
789789
avCodecContext_->pix_fmt = outPixelFormat_;
790-
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
791-
avCodecContext_->time_base = {1, inFrameRate_};
792-
avCodecContext_->framerate = {inFrameRate_, 1};
790+
// TODO-VideoEncoder: Add and utilize output frame_rate option
791+
AVRational frameRate = av_d2q(inFrameRate_, INT_MAX);
792+
avCodecContext_->time_base = av_inv_q(frameRate);
793+
avCodecContext_->framerate = frameRate;
793794

794795
// Set flag for containers that require extradata to be in the codec context
795796
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,13 @@ class VideoEncoder {
143143

144144
VideoEncoder(
145145
const torch::Tensor& frames,
146-
int frameRate,
146+
double frameRate,
147147
std::string_view fileName,
148148
const VideoStreamOptions& videoStreamOptions);
149149

150150
VideoEncoder(
151151
const torch::Tensor& frames,
152-
int frameRate,
152+
double frameRate,
153153
std::string_view formatName,
154154
std::unique_ptr<AVIOContextHolder> avioContextHolder,
155155
const VideoStreamOptions& videoStreamOptions);
@@ -172,7 +172,7 @@ class VideoEncoder {
172172
UniqueSwsContext swsContext_;
173173

174174
const torch::Tensor frames_;
175-
int inFrameRate_;
175+
double inFrameRate_;
176176

177177
int inWidth_ = -1;
178178
int inHeight_ = -1;

src/torchcodec/_core/custom_ops.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3737
m.def(
3838
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3939
m.def(
40-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
40+
"encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4141
m.def(
42-
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
42+
"encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor");
4343
m.def(
44-
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
44+
"_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()");
4545
m.def(
4646
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4747
m.def(
@@ -611,7 +611,7 @@ void _encode_audio_to_file_like(
611611

612612
void encode_video_to_file(
613613
const at::Tensor& frames,
614-
int64_t frame_rate,
614+
double frame_rate,
615615
std::string_view file_name,
616616
std::optional<std::string_view> codec = std::nullopt,
617617
std::optional<std::string_view> pixel_format = std::nullopt,
@@ -629,17 +629,12 @@ void encode_video_to_file(
629629
unflattenExtraOptions(extra_options.value());
630630
}
631631

632-
VideoEncoder(
633-
frames,
634-
validateInt64ToInt(frame_rate, "frame_rate"),
635-
file_name,
636-
videoStreamOptions)
637-
.encode();
632+
VideoEncoder(frames, frame_rate, file_name, videoStreamOptions).encode();
638633
}
639634

640635
at::Tensor encode_video_to_tensor(
641636
const at::Tensor& frames,
642-
int64_t frame_rate,
637+
double frame_rate,
643638
std::string_view format,
644639
std::optional<std::string_view> codec = std::nullopt,
645640
std::optional<std::string_view> pixel_format = std::nullopt,
@@ -660,7 +655,7 @@ at::Tensor encode_video_to_tensor(
660655

661656
return VideoEncoder(
662657
frames,
663-
validateInt64ToInt(frame_rate, "frame_rate"),
658+
frame_rate,
664659
format,
665660
std::move(avioContextHolder),
666661
videoStreamOptions)
@@ -669,7 +664,7 @@ at::Tensor encode_video_to_tensor(
669664

670665
void _encode_video_to_file_like(
671666
const at::Tensor& frames,
672-
int64_t frame_rate,
667+
double frame_rate,
673668
std::string_view format,
674669
int64_t file_like_context,
675670
std::optional<std::string_view> codec = std::nullopt,
@@ -696,7 +691,7 @@ void _encode_video_to_file_like(
696691

697692
VideoEncoder encoder(
698693
frames,
699-
validateInt64ToInt(frame_rate, "frame_rate"),
694+
frame_rate,
700695
format,
701696
std::move(avioContextHolder),
702697
videoStreamOptions);

src/torchcodec/_core/ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def encode_audio_to_file_like(
210210

211211
def encode_video_to_file_like(
212212
frames: torch.Tensor,
213-
frame_rate: int,
213+
frame_rate: float,
214214
format: str,
215215
file_like: Union[io.RawIOBase, io.BufferedIOBase],
216216
codec: Optional[str] = None,
@@ -329,7 +329,7 @@ def _encode_audio_to_file_like_abstract(
329329
@register_fake("torchcodec_ns::encode_video_to_file")
330330
def encode_video_to_file_abstract(
331331
frames: torch.Tensor,
332-
frame_rate: int,
332+
frame_rate: float,
333333
filename: str,
334334
codec: Optional[str] = None,
335335
pixel_format: Optional[str] = None,
@@ -343,7 +343,7 @@ def encode_video_to_file_abstract(
343343
@register_fake("torchcodec_ns::encode_video_to_tensor")
344344
def encode_video_to_tensor_abstract(
345345
frames: torch.Tensor,
346-
frame_rate: int,
346+
frame_rate: float,
347347
format: str,
348348
codec: Optional[str] = None,
349349
pixel_format: Optional[str] = None,
@@ -357,7 +357,7 @@ def encode_video_to_tensor_abstract(
357357
@register_fake("torchcodec_ns::_encode_video_to_file_like")
358358
def _encode_video_to_file_like_abstract(
359359
frames: torch.Tensor,
360-
frame_rate: int,
360+
frame_rate: float,
361361
format: str,
362362
file_like_context: int,
363363
codec: Optional[str] = None,

src/torchcodec/encoders/_video_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ class VideoEncoder:
1515
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
1616
C is 3 channels (RGB), H is height, and W is width.
1717
Values must be uint8 in the range ``[0, 255]``.
18-
frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
18+
frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
1919
"""
2020

21-
def __init__(self, frames: Tensor, *, frame_rate: int):
21+
def __init__(self, frames: Tensor, *, frame_rate: float):
2222
torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder")
2323
if not isinstance(frames, Tensor):
2424
raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.")

test/test_encoders.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,13 @@ def write(self, data):
570570

571571
class TestVideoEncoder:
572572
def decode(self, source=None) -> torch.Tensor:
573-
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
573+
return VideoDecoder(source).get_frames_in_range(start=0, stop=60).data
574+
575+
def decode_and_get_frame_rate(self, source=None):
576+
decoder = VideoDecoder(source)
577+
frames = decoder.get_frames_in_range(start=0, stop=60).data
578+
frame_rate = decoder.metadata.average_fps
579+
return frames, frame_rate
574580

575581
def _get_video_metadata(self, file_path, fields):
576582
"""Helper function to get video metadata from a file using ffprobe."""
@@ -826,26 +832,25 @@ def test_round_trip(self, tmp_path, format, method):
826832
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
827833
):
828834
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
829-
source_frames = self.decode(TEST_SRC_2_720P.path).data
835+
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
830836

831-
# Frame rate is fixed with num frames decoded
832-
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
837+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
833838

834839
if method == "to_file":
835840
encoded_path = str(tmp_path / f"encoder_output.{format}")
836841
encoder.to_file(dest=encoded_path, pixel_format="yuv444p", crf=0)
837-
round_trip_frames = self.decode(encoded_path).data
842+
round_trip_frames = self.decode(encoded_path)
838843
elif method == "to_tensor":
839844
encoded_tensor = encoder.to_tensor(
840845
format=format, pixel_format="yuv444p", crf=0
841846
)
842-
round_trip_frames = self.decode(encoded_tensor).data
847+
round_trip_frames = self.decode(encoded_tensor)
843848
elif method == "to_file_like":
844849
file_like = io.BytesIO()
845850
encoder.to_file_like(
846851
file_like=file_like, format=format, pixel_format="yuv444p", crf=0
847852
)
848-
round_trip_frames = self.decode(file_like.getvalue()).data
853+
round_trip_frames = self.decode(file_like.getvalue())
849854
else:
850855
raise ValueError(f"Unknown method: {method}")
851856

@@ -878,8 +883,8 @@ def test_against_to_file(self, tmp_path, format, method):
878883
):
879884
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
880885

881-
source_frames = self.decode(TEST_SRC_2_720P.path).data
882-
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
886+
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
887+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
883888

884889
encoded_file = tmp_path / f"output.{format}"
885890
encoder.to_file(dest=encoded_file, crf=0)
@@ -892,8 +897,8 @@ def test_against_to_file(self, tmp_path, format, method):
892897
encoded_output = file_like.getvalue()
893898

894899
torch.testing.assert_close(
895-
self.decode(encoded_file).data,
896-
self.decode(encoded_output).data,
900+
self.decode(encoded_file),
901+
self.decode(encoded_output),
897902
atol=0,
898903
rtol=0,
899904
)
@@ -936,15 +941,14 @@ def test_video_encoder_against_ffmpeg_cli(
936941
if format in ("avi", "flv") and pixel_format == "yuv444p":
937942
pytest.skip(f"Default codec for {format} does not support {pixel_format}")
938943

939-
source_frames = self.decode(TEST_SRC_2_720P.path).data
944+
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
940945

941946
# Encode with FFmpeg CLI
942947
temp_raw_path = str(tmp_path / "temp_input.raw")
943948
with open(temp_raw_path, "wb") as f:
944949
f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes())
945950

946951
ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}")
947-
frame_rate = 30
948952
# Some codecs (ex. MPEG4) do not support CRF or preset.
949953
# Flags not supported by the selected codec will be ignored.
950954
ffmpeg_cmd = [
@@ -983,15 +987,15 @@ def test_video_encoder_against_ffmpeg_cli(
983987
crf=crf,
984988
preset=preset,
985989
)
986-
encoder_frames = self.decode(encoder_output_path).data
990+
encoder_frames = self.decode(encoder_output_path)
987991
elif method == "to_tensor":
988992
encoded_output = encoder.to_tensor(
989993
format=format,
990994
pixel_format=pixel_format,
991995
crf=crf,
992996
preset=preset,
993997
)
994-
encoder_frames = self.decode(encoded_output).data
998+
encoder_frames = self.decode(encoded_output)
995999
elif method == "to_file_like":
9961000
file_like = io.BytesIO()
9971001
encoder.to_file_like(
@@ -1001,7 +1005,7 @@ def test_video_encoder_against_ffmpeg_cli(
10011005
crf=crf,
10021006
preset=preset,
10031007
)
1004-
encoder_frames = self.decode(file_like.getvalue()).data
1008+
encoder_frames = self.decode(file_like.getvalue())
10051009
else:
10061010
raise ValueError(f"Unknown method: {method}")
10071011

@@ -1047,24 +1051,24 @@ def seek(self, offset, whence=0):
10471051
def get_encoded_data(self):
10481052
return self._file.getvalue()
10491053

1050-
source_frames = self.decode(TEST_SRC_2_720P.path).data
1051-
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
1054+
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
1055+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
10521056

10531057
file_like = CustomFileObject()
10541058
encoder.to_file_like(file_like, format="mp4", pixel_format="yuv444p", crf=0)
10551059
decoded_frames = self.decode(file_like.get_encoded_data())
10561060

10571061
torch.testing.assert_close(
1058-
decoded_frames.data,
1062+
decoded_frames,
10591063
source_frames,
10601064
atol=2,
10611065
rtol=0,
10621066
)
10631067

10641068
def test_to_file_like_real_file(self, tmp_path):
10651069
"""Test to_file_like with a real file opened in binary write mode."""
1066-
source_frames = self.decode(TEST_SRC_2_720P.path).data
1067-
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
1070+
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
1071+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
10681072

10691073
file_path = tmp_path / "test_file_like.mp4"
10701074

@@ -1073,15 +1077,15 @@ def test_to_file_like_real_file(self, tmp_path):
10731077
decoded_frames = self.decode(str(file_path))
10741078

10751079
torch.testing.assert_close(
1076-
decoded_frames.data,
1080+
decoded_frames,
10771081
source_frames,
10781082
atol=2,
10791083
rtol=0,
10801084
)
10811085

10821086
def test_to_file_like_bad_methods(self):
1083-
source_frames = self.decode(TEST_SRC_2_720P.path).data
1084-
encoder = VideoEncoder(frames=source_frames, frame_rate=30)
1087+
source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path)
1088+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
10851089

10861090
class NoWriteMethod:
10871091
def seek(self, offset, whence=0):
@@ -1174,8 +1178,8 @@ def test_codec_spec_vs_impl_equivalence(self, tmp_path, codec_spec, codec_impl):
11741178
== codec_spec
11751179
)
11761180

1177-
frames_spec = self.decode(spec_output).data
1178-
frames_impl = self.decode(impl_output).data
1181+
frames_spec = self.decode(spec_output)
1182+
frames_impl = self.decode(impl_output)
11791183
torch.testing.assert_close(frames_spec, frames_impl, rtol=0, atol=0)
11801184

11811185
@pytest.mark.skipif(in_fbcode(), reason="ffprobe not available")
@@ -1210,3 +1214,20 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range
12101214
assert metadata["profile"].lower() == expected_profile
12111215
assert metadata["color_space"] == colorspace
12121216
assert metadata["color_range"] == color_range
1217+
1218+
@pytest.mark.parametrize("frame_rate", [29.97, 59.94, 5.001])
1219+
def test_fractional_frame_rate(self, tmp_path, frame_rate):
1220+
source_frames = torch.zeros((10, 3, 64, 64), dtype=torch.uint8)
1221+
encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate)
1222+
output_path = str(tmp_path / "output.mp4")
1223+
encoder.to_file(dest=output_path)
1224+
# Assert the encoded frame rate via file metadata
1225+
metadata = self._get_video_metadata(output_path, fields=["r_frame_rate"])
1226+
num, den = metadata["r_frame_rate"].split("/")
1227+
encoded_frame_rate = int(num) / int(den)
1228+
assert encoded_frame_rate == frame_rate
1229+
# Assert the decoded frame rate matches the input frame rate
1230+
_decoded_frames, decoded_frame_rate = self.decode_and_get_frame_rate(
1231+
output_path
1232+
)
1233+
assert decoded_frame_rate == frame_rate

0 commit comments

Comments
 (0)