Skip to content

Commit f3b56f8

Browse files
committed
Add proper error when backward seek is neede
1 parent d2357fe commit f3b56f8

File tree

3 files changed

+61
-15
lines changed

3 files changed

+61
-15
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -854,14 +854,25 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
854854

855855
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
856856

857-
auto lastDecodedFrameIsPlayedAtStopSeconds =
858-
[this, &streamInfo, stopSeconds]() {
859-
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
860-
return (
861-
streamInfo.lastDecodedAvFramePts <= stopPts and
862-
stopPts <= streamInfo.lastDecodedAvFramePts +
863-
streamInfo.lastDecodedAvFrameDuration);
864-
};
857+
auto lastDecodedFrameIsPlayedAt = [this, &streamInfo](double seconds) {
858+
auto pts = secondsToClosestPts(seconds, streamInfo.timeBase);
859+
return (
860+
streamInfo.lastDecodedAvFramePts <= pts and
861+
pts <= streamInfo.lastDecodedAvFramePts +
862+
streamInfo.lastDecodedAvFrameDuration);
863+
};
864+
865+
// TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
866+
// We should remove it and seek back to the stream's beginning when needed.
867+
// See test_multiple_calls
868+
TORCH_CHECK(
869+
(streamInfo.lastDecodedAvFramePts == 0 &&
870+
streamInfo.lastDecodedAvFrameDuration == 0) ||
871+
(streamInfo.lastDecodedAvFramePts +
872+
streamInfo.lastDecodedAvFrameDuration <=
873+
secondsToClosestPts(startSeconds, streamInfo.timeBase)) ||
874+
!lastDecodedFrameIsPlayedAt(startSeconds),
875+
"The previous call's stop_seconds is larger than the current calls's start_seconds (roughly)");
865876

866877
setCursorPtsInSeconds(startSeconds);
867878

@@ -871,7 +882,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
871882
std::vector<torch::Tensor> tensors;
872883

873884
bool reachedEOF = false;
874-
while (!lastDecodedFrameIsPlayedAtStopSeconds() && !reachedEOF) {
885+
while (!lastDecodedFrameIsPlayedAt(stopSeconds) && !reachedEOF) {
875886
try {
876887
AVFrameStream avFrameStream =
877888
decodeAVFrame([&streamInfo](AVFrame* avFrame) {

test/decoders/test_ops.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,13 @@ def test_decode_just_one_frame_at_boundaries(self, asset, expected_shape):
725725

726726
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
727727
def test_multiple_calls(self, asset):
728+
# Ensure that multiple calls are OK as long as we're decoding
729+
# "sequentially", i.e. we don't require a backwards seek.
730+
# And ensure a proper error is raised in such case.
731+
# TODO-AUDIO We shouldn't error, we should just implement the seeking
732+
# back to the beginning of the stream.
728733

729-
def decode_stateless(start_seconds, stop_seconds):
734+
def get_reference_frames(start_seconds, stop_seconds):
730735
decoder = create_from_file(str(asset.path), seek_mode="approximate")
731736
add_audio_stream(decoder)
732737

@@ -742,25 +747,56 @@ def decode_stateless(start_seconds, stop_seconds):
742747
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
743748
)
744749
torch.testing.assert_close(
745-
frames, decode_stateless(start_seconds, stop_seconds)
750+
frames, get_reference_frames(start_seconds, stop_seconds)
746751
)
747752

748753
start_seconds, stop_seconds = 3, 4
749754
frames = get_frames_by_pts_in_range_audio(
750755
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
751756
)
752757
torch.testing.assert_close(
753-
frames, decode_stateless(start_seconds, stop_seconds)
758+
frames, get_reference_frames(start_seconds, stop_seconds)
754759
)
755760

756-
# TODO-AUDIO
761+
# Starting at the frame immediately after the previous one is OK
762+
index_of_frame_at_4 = asset.get_frame_index(pts_seconds=4)
763+
start_seconds, stop_seconds = (
764+
asset.frames[asset.default_stream_index][
765+
index_of_frame_at_4 + 1
766+
].pts_seconds,
767+
5,
768+
)
769+
frames = get_frames_by_pts_in_range_audio(
770+
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
771+
)
772+
torch.testing.assert_close(
773+
frames, get_reference_frames(start_seconds, stop_seconds)
774+
)
775+
776+
# But starting immediately on the same frame isn't OK
777+
with pytest.raises(
778+
RuntimeError,
779+
match="The previous call's stop_seconds is larger than the current calls's start_seconds",
780+
):
781+
get_frames_by_pts_in_range_audio(
782+
decoder, start_seconds=stop_seconds, stop_seconds=6
783+
)
784+
785+
with pytest.raises(
786+
RuntimeError,
787+
match="The previous call's stop_seconds is larger than the current calls's start_seconds",
788+
):
789+
get_frames_by_pts_in_range_audio(
790+
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
791+
)
792+
757793
start_seconds, stop_seconds = 0, 2
758794
frames = get_frames_by_pts_in_range_audio(
759795
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
760796
)
761797
with pytest.raises(AssertionError):
762798
torch.testing.assert_close(
763-
frames, decode_stateless(start_seconds, stop_seconds)
799+
frames, get_reference_frames(start_seconds, stop_seconds)
764800
)
765801

766802

test/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def assert_tensor_close_on_at_least(actual_tensor, ref_tensor, *, percentage, at
7373
)
7474

7575

76-
7776
def in_fbcode() -> bool:
7877
return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"
7978

0 commit comments

Comments
 (0)