Skip to content

Commit 981a27e

Browse files
author
Daniel Flores
committed
to_filelike, update test
1 parent 80d3999 commit 981a27e

File tree

4 files changed

+104
-6
lines changed

4 files changed

+104
-6
lines changed

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_file_like,
2930
encode_video_to_tensor,
3031
get_ffmpeg_library_versions,
3132
get_frame_at_index,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
4141
m.def(
4242
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
43+
m.def(
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
4345
m.def(
4446
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4547
m.def(
@@ -606,6 +608,30 @@ at::Tensor encode_video_to_tensor(
606608
.encodeToTensor();
607609
}
608610

611+
void _encode_video_to_file_like(
612+
const at::Tensor& frames,
613+
int64_t frame_rate,
614+
std::string_view format,
615+
int64_t file_like_context,
616+
std::optional<int64_t> crf = std::nullopt) {
617+
auto fileLikeContext =
618+
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
619+
TORCH_CHECK(
620+
fileLikeContext != nullptr, "file_like_context must be a valid pointer");
621+
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
622+
623+
VideoStreamOptions videoStreamOptions;
624+
videoStreamOptions.crf = crf;
625+
626+
VideoEncoder encoder(
627+
frames,
628+
validateInt64ToInt(frame_rate, "frame_rate"),
629+
format,
630+
std::move(avioContextHolder),
631+
videoStreamOptions);
632+
encoder.encode();
633+
}
634+
609635
// For testing only. We need to implement this operation as a core library
610636
// function because what we're testing is round-tripping pts values as
611637
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -870,6 +896,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
870896
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
871897
m.impl("encode_video_to_file", &encode_video_to_file);
872898
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
899+
m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
873900
m.impl("seek_to_pts", &seek_to_pts);
874901
m.impl("add_video_stream", &add_video_stream);
875902
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def load_torchcodec_shared_libraries():
104104
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
105105
torch.ops.torchcodec_ns.encode_video_to_tensor.default
106106
)
107+
_encode_video_to_file_like = torch._dynamo.disallow_in_graph(
108+
torch.ops.torchcodec_ns._encode_video_to_file_like.default
109+
)
107110
create_from_tensor = torch._dynamo.disallow_in_graph(
108111
torch.ops.torchcodec_ns.create_from_tensor.default
109112
)
@@ -203,6 +206,33 @@ def encode_audio_to_file_like(
203206
)
204207

205208

209+
def encode_video_to_file_like(
210+
frames: torch.Tensor,
211+
frame_rate: int,
212+
format: str,
213+
file_like: Union[io.RawIOBase, io.BufferedIOBase],
214+
crf: Optional[int] = None,
215+
) -> None:
216+
"""Encode video frames to a file-like object.
217+
218+
Args:
219+
frames: Video frames tensor
220+
frame_rate: Frame rate in frames per second
221+
format: Video format (e.g., "mp4", "mov", "mkv")
222+
file_like: File-like object that supports write() and seek() methods
223+
crf: Optional constant rate factor for encoding quality
224+
"""
225+
assert _pybind_ops is not None
226+
227+
_encode_video_to_file_like(
228+
frames,
229+
frame_rate,
230+
format,
231+
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
232+
crf,
233+
)
234+
235+
206236
def get_frames_at_indices(
207237
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
208238
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -302,6 +332,17 @@ def encode_video_to_tensor_abstract(
302332
return torch.empty([], dtype=torch.long)
303333

304334

335+
@register_fake("torchcodec_ns::_encode_video_to_file_like")
336+
def _encode_video_to_file_like_abstract(
337+
frames: torch.Tensor,
338+
frame_rate: int,
339+
format: str,
340+
file_like_context: int,
341+
crf: Optional[int] = None,
342+
) -> None:
343+
return
344+
345+
305346
@register_fake("torchcodec_ns::create_from_tensor")
306347
def create_from_tensor_abstract(
307348
video_tensor: torch.Tensor, seek_mode: Optional[str]

test/test_ops.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
create_from_tensor,
3030
encode_audio_to_file,
3131
encode_video_to_file,
32+
encode_video_to_file_like,
3233
encode_video_to_tensor,
3334
get_ffmpeg_library_versions,
3435
get_frame_at_index,
@@ -1394,7 +1395,7 @@ def decode(self, source=None) -> torch.Tensor:
13941395
return VideoDecoder(source).get_frames_in_range(start=0, stop=60)
13951396

13961397
@pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm"))
1397-
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
1398+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
13981399
def test_video_encoder_round_trip(self, tmp_path, format, method):
13991400
# Test that decode(encode(decode(frames))) == decode(frames)
14001401
ffmpeg_version = get_ffmpeg_major_version()
@@ -1421,11 +1422,23 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
14211422
**params,
14221423
)
14231424
round_trip_frames = self.decode(encoded_path).data
1424-
else: # to_tensor
1425+
elif method == "to_tensor":
1426+
format = "matroska" if format == "mkv" else format
14251427
encoded_tensor = encode_video_to_tensor(
14261428
source_frames, format=format, **params
14271429
)
14281430
round_trip_frames = self.decode(encoded_tensor).data
1431+
else: # to_file_like
1432+
format = "matroska" if format == "mkv" else format
1433+
file_like = io.BytesIO()
1434+
encode_video_to_file_like(
1435+
frames=source_frames,
1436+
format=format,
1437+
file_like=file_like,
1438+
**params,
1439+
)
1440+
file_like.seek(0)
1441+
round_trip_frames = self.decode(file_like).data
14291442

14301443
assert source_frames.shape == round_trip_frames.shape
14311444
assert source_frames.dtype == round_trip_frames.dtype
@@ -1442,11 +1455,13 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
14421455
assert psnr(s_frame, rt_frame) > 30
14431456
assert_close(s_frame, rt_frame, atol=atol, rtol=0)
14441457

1458+
@pytest.mark.slow
14451459
@pytest.mark.parametrize(
14461460
"format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif")
14471461
)
1448-
def test_against_to_file(self, tmp_path, format):
1449-
# Test that to_file and to_tensor produce the same results
1462+
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
1463+
def test_against_to_file(self, tmp_path, format, method):
1464+
# Test that to_file, to_tensor, and to_file_like produce the same results
14501465
ffmpeg_version = get_ffmpeg_major_version()
14511466
if format == "webm" and (
14521467
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
@@ -1458,10 +1473,24 @@ def test_against_to_file(self, tmp_path, format):
14581473

14591474
encoded_file = tmp_path / f"output.{format}"
14601475
encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params)
1461-
encoded_tensor = encode_video_to_tensor(source_frames, format=format, **params)
1476+
1477+
if method == "to_tensor":
1478+
encoded_output = encode_video_to_tensor(
1479+
source_frames, format=format, **params
1480+
)
1481+
else: # to_file_like
1482+
file_like = io.BytesIO()
1483+
encode_video_to_file_like(
1484+
frames=source_frames,
1485+
file_like=file_like,
1486+
format=format,
1487+
**params,
1488+
)
1489+
file_like.seek(0)
1490+
encoded_output = file_like
14621491

14631492
torch.testing.assert_close(
1464-
self.decode(encoded_file).data, self.decode(encoded_tensor).data
1493+
self.decode(encoded_file).data, self.decode(encoded_output).data
14651494
)
14661495

14671496
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")

0 commit comments

Comments
 (0)