Skip to content

Commit fe04cd2

Browse files
committed
Add support for None stop_seconds
1 parent f4bed23 commit fe04cd2

File tree

8 files changed

+29
-19
lines changed

8 files changed

+29
-19
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <cstdint>
99
#include <cstdio>
1010
#include <iostream>
11+
#include <limits>
1112
#include <sstream>
1213
#include <stdexcept>
1314
#include <string_view>
@@ -840,7 +841,9 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
840841

841842
torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
842843
double startSeconds,
843-
double stopSeconds) {
844+
std::optional<double> _stopSeconds) {
845+
auto stopSeconds = _stopSeconds.value_or(std::numeric_limits<double>::max());
846+
844847
TORCH_CHECK(
845848
startSeconds <= stopSeconds,
846849
"Start seconds (" + std::to_string(startSeconds) +

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -223,7 +223,7 @@ class VideoDecoder {
223223

224224
torch::Tensor getFramesPlayedInRangeAudio(
225225
double startSeconds,
226-
double stopSeconds);
226+
std::optional<double> _stopSeconds = std::nullopt);
227227

228228
class EndOfFileException : public std::runtime_error {
229229
public:

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,7 @@ namespace facebook::torchcodec {
2525
// https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
2626
TORCH_LIBRARY(torchcodec_ns, m) {
2727
m.impl_abstract_pystub(
28-
"torchcodec.decoders._core.video_decoder_ops",
29-
"//pytorch/torchcodec:torchcodec");
28+
"torchcodec.decoders._core.ops", "//pytorch/torchcodec:torchcodec");
3029
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3130
m.def(
3231
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
@@ -49,7 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4948
m.def(
5049
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
5150
m.def(
52-
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> Tensor");
51+
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> Tensor");
5352
m.def(
5453
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
5554
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
@@ -308,7 +307,7 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
308307
torch::Tensor get_frames_by_pts_in_range_audio(
309308
at::Tensor& decoder,
310309
double start_seconds,
311-
double stop_seconds) {
310+
std::optional<double> stop_seconds) {
312311
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
313312
return videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
314313
}

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
122122
torch::Tensor get_frames_by_pts_in_range_audio(
123123
at::Tensor& decoder,
124124
double start_seconds,
125-
double stop_seconds);
125+
std::optional<double> stop_seconds = std::nullopt);
126126

127127
// For testing only. We need to implement this operation as a core library
128128
// function because what we're testing is round-tripping pts values as

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
get_container_metadata_from_header,
1313
VideoStreamMetadata,
1414
)
15-
from .video_decoder_ops import (
15+
from .ops import (
1616
_add_video_stream,
1717
_get_key_frame_indices,
1818
_test_frame_pts_equality,

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
import torch
1414

15-
from torchcodec.decoders._core.video_decoder_ops import (
15+
from torchcodec.decoders._core.ops import (
1616
_get_container_json_metadata,
1717
_get_stream_json_metadata,
1818
create_from_file,

src/torchcodec/decoders/_core/ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,7 @@ def get_frames_by_pts_in_range_audio_abstract(
270270
decoder: torch.Tensor,
271271
*,
272272
start_seconds: float,
273-
stop_seconds: float,
273+
stop_seconds: Optional[float] = None,
274274
) -> torch.Tensor:
275275
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
276276
return torch.empty(image_size)

test/decoders/test_ops.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -646,6 +646,7 @@ def test_audio_bad_seek_mode(self):
646646
"range",
647647
(
648648
"begin_to_end",
649+
"begin_to_None",
649650
"begin_to_beyond_end",
650651
"at_frame_boundaries",
651652
"not_at_frame_boundaries",
@@ -655,6 +656,8 @@ def test_audio_bad_seek_mode(self):
655656
def test_get_frames_by_pts_in_range_audio(self, range, asset):
656657
if range == "begin_to_end":
657658
start_seconds, stop_seconds = 0, asset.duration_seconds
659+
elif range == "begin_to_None":
660+
start_seconds, stop_seconds = 0, None
658661
elif range == "begin_to_beyond_end":
659662
start_seconds, stop_seconds = 0, asset.duration_seconds + 10
660663
elif range == "at_frame_boundaries":
@@ -671,18 +674,23 @@ def test_get_frames_by_pts_in_range_audio(self, range, asset):
671674
stop_frame_info.duration_seconds / 2
672675
)
673676

674-
decoder = create_from_file(str(asset.path), seek_mode="approximate")
675-
add_audio_stream(decoder)
676-
677-
# stop_offset logic: if stop_seconds is at a frame boundary i.e. when a
678-
# frame starts, then that frame should *not* be included in the output.
679-
# Otherwise, it should be part of it, hence why we add 1 to `stop=`.
680-
stop_offset = 0 if range == "at_frame_boundaries" else 1
677+
ref_start_index = asset.get_frame_index(pts_seconds=start_seconds)
678+
if range == "begin_to_None":
679+
ref_stop_index = (
680+
asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
681+
)
682+
elif range == "at_frame_boundaries":
683+
ref_stop_index = asset.get_frame_index(pts_seconds=stop_seconds)
684+
else:
685+
ref_stop_index = asset.get_frame_index(pts_seconds=stop_seconds) + 1
681686
reference_frames = asset.get_frame_data_by_range(
682-
start=asset.get_frame_index(pts_seconds=start_seconds),
683-
stop=asset.get_frame_index(pts_seconds=stop_seconds) + stop_offset,
687+
start=ref_start_index,
688+
stop=ref_stop_index,
684689
)
685690

691+
decoder = create_from_file(str(asset.path), seek_mode="approximate")
692+
add_audio_stream(decoder)
693+
686694
frames = get_frames_by_pts_in_range_audio(
687695
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
688696
)

0 commit comments

Comments
 (0)