Skip to content

Commit 050d91a

Browse files
author
Daniel Flores
committed
to_filelike, update test
1 parent 45222dc commit 050d91a

File tree

4 files changed

+104
-6
lines changed

4 files changed

+104
-6
lines changed

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_file_like,
2930
encode_video_to_tensor,
3031
get_ffmpeg_library_versions,
3132
get_frame_at_index,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
4141
m.def(
4242
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
43+
m.def(
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
4345
m.def(
4446
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4547
m.def(
@@ -606,6 +608,30 @@ at::Tensor encode_video_to_tensor(
606608
.encodeToTensor();
607609
}
608610

611+
void _encode_video_to_file_like(
612+
const at::Tensor& frames,
613+
int64_t frame_rate,
614+
std::string_view format,
615+
int64_t file_like_context,
616+
std::optional<int64_t> crf = std::nullopt) {
617+
auto fileLikeContext =
618+
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
619+
TORCH_CHECK(
620+
fileLikeContext != nullptr, "file_like_context must be a valid pointer");
621+
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
622+
623+
VideoStreamOptions videoStreamOptions;
624+
videoStreamOptions.crf = crf;
625+
626+
VideoEncoder encoder(
627+
frames,
628+
validateInt64ToInt(frame_rate, "frame_rate"),
629+
format,
630+
std::move(avioContextHolder),
631+
videoStreamOptions);
632+
encoder.encode();
633+
}
634+
609635
// For testing only. We need to implement this operation as a core library
610636
// function because what we're testing is round-tripping pts values as
611637
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -870,6 +896,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
870896
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
871897
m.impl("encode_video_to_file", &encode_video_to_file);
872898
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
899+
m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
873900
m.impl("seek_to_pts", &seek_to_pts);
874901
m.impl("add_video_stream", &add_video_stream);
875902
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def load_torchcodec_shared_libraries():
104104
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
105105
torch.ops.torchcodec_ns.encode_video_to_tensor.default
106106
)
107+
_encode_video_to_file_like = torch._dynamo.disallow_in_graph(
108+
torch.ops.torchcodec_ns._encode_video_to_file_like.default
109+
)
107110
create_from_tensor = torch._dynamo.disallow_in_graph(
108111
torch.ops.torchcodec_ns.create_from_tensor.default
109112
)
@@ -203,6 +206,33 @@ def encode_audio_to_file_like(
203206
)
204207

205208

209+
def encode_video_to_file_like(
210+
frames: torch.Tensor,
211+
frame_rate: int,
212+
format: str,
213+
file_like: Union[io.RawIOBase, io.BufferedIOBase],
214+
crf: Optional[int] = None,
215+
) -> None:
216+
"""Encode video frames to a file-like object.
217+
218+
Args:
219+
frames: Video frames tensor
220+
frame_rate: Frame rate in frames per second
221+
format: Video format (e.g., "mp4", "mov", "mkv")
222+
file_like: File-like object that supports write() and seek() methods
223+
crf: Optional constant rate factor for encoding quality
224+
"""
225+
assert _pybind_ops is not None
226+
227+
_encode_video_to_file_like(
228+
frames,
229+
frame_rate,
230+
format,
231+
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
232+
crf,
233+
)
234+
235+
206236
def get_frames_at_indices(
207237
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
208238
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -302,6 +332,17 @@ def encode_video_to_tensor_abstract(
302332
return torch.empty([], dtype=torch.long)
303333

304334

335+
@register_fake("torchcodec_ns::_encode_video_to_file_like")
336+
def _encode_video_to_file_like_abstract(
337+
frames: torch.Tensor,
338+
frame_rate: int,
339+
format: str,
340+
file_like_context: int,
341+
crf: Optional[int] = None,
342+
) -> None:
343+
return
344+
345+
305346
@register_fake("torchcodec_ns::create_from_tensor")
306347
def create_from_tensor_abstract(
307348
video_tensor: torch.Tensor, seek_mode: Optional[str]

test/test_ops.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
create_from_tensor,
3030
encode_audio_to_file,
3131
encode_video_to_file,
32+
encode_video_to_file_like,
3233
encode_video_to_tensor,
3334
get_ffmpeg_library_versions,
3435
get_frame_at_index,
@@ -1397,7 +1398,7 @@ def decode(self, source=None) -> torch.Tensor:
13971398
@pytest.mark.parametrize(
13981399
"format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
13991400
)
1400-
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
1401+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
14011402
def test_video_encoder_round_trip(self, tmp_path, format, method):
14021403
# Test that decode(encode(decode(frames))) == decode(frames)
14031404
ffmpeg_version = get_ffmpeg_major_version()
@@ -1424,11 +1425,23 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
14241425
**params,
14251426
)
14261427
round_trip_frames = self.decode(encoded_path).data
1427-
else: # to_tensor
1428+
elif method == "to_tensor":
14281429
encoded_tensor = encode_video_to_tensor(
14291430
source_frames, format=format, **params
14301431
)
14311432
round_trip_frames = self.decode(encoded_tensor).data
1433+
elif method == "to_file_like":
1434+
file_like = io.BytesIO()
1435+
encode_video_to_file_like(
1436+
frames=source_frames,
1437+
format=format,
1438+
file_like=file_like,
1439+
**params,
1440+
)
1441+
file_like.seek(0)
1442+
round_trip_frames = self.decode(file_like).data
1443+
else:
1444+
raise ValueError(f"Unknown method: {method}")
14321445

14331446
assert source_frames.shape == round_trip_frames.shape
14341447
assert source_frames.dtype == round_trip_frames.dtype
@@ -1445,6 +1458,7 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
14451458
assert psnr(s_frame, rt_frame) > 30
14461459
assert_close(s_frame, rt_frame, atol=atol, rtol=0)
14471460

1461+
@pytest.mark.slow
14481462
@pytest.mark.parametrize(
14491463
"format",
14501464
(
@@ -1457,8 +1471,9 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
14571471
pytest.param("webm", marks=pytest.mark.slow),
14581472
),
14591473
)
1460-
def test_against_to_file(self, tmp_path, format):
1461-
# Test that to_file and to_tensor produce the same results
1474+
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
1475+
def test_against_to_file(self, tmp_path, format, method):
1476+
# Test that to_file, to_tensor, and to_file_like produce the same results
14621477
ffmpeg_version = get_ffmpeg_major_version()
14631478
if format == "webm" and (
14641479
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
@@ -1470,11 +1485,25 @@ def test_against_to_file(self, tmp_path, format):
14701485

14711486
encoded_file = tmp_path / f"output.{format}"
14721487
encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params)
1473-
encoded_tensor = encode_video_to_tensor(source_frames, format=format, **params)
1488+
1489+
if method == "to_tensor":
1490+
encoded_output = encode_video_to_tensor(
1491+
source_frames, format=format, **params
1492+
)
1493+
else: # to_file_like
1494+
file_like = io.BytesIO()
1495+
encode_video_to_file_like(
1496+
frames=source_frames,
1497+
file_like=file_like,
1498+
format=format,
1499+
**params,
1500+
)
1501+
file_like.seek(0)
1502+
encoded_output = file_like
14741503

14751504
torch.testing.assert_close(
14761505
self.decode(encoded_file).data,
1477-
self.decode(encoded_tensor).data,
1506+
self.decode(encoded_output).data,
14781507
atol=0,
14791508
rtol=0,
14801509
)

0 commit comments

Comments
 (0)