Skip to content

Commit 09e6f44

Browse files
committed
Oops, fix
1 parent b5f2df0 commit 09e6f44

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -855,14 +855,6 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
855855

856856
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
857857

858-
auto lastDecodedFrameIsPlayedAt = [&streamInfo](double seconds) {
859-
auto pts = secondsToClosestPts(seconds, streamInfo.timeBase);
860-
return (
861-
streamInfo.lastDecodedAvFramePts <= pts and
862-
pts <= streamInfo.lastDecodedAvFramePts +
863-
streamInfo.lastDecodedAvFrameDuration);
864-
};
865-
866858
// TODO-AUDIO This essentially enforces that we don't need to seek (backwards).
867859
// We should remove it and seek back to the stream's beginning when needed.
868860
// See test_multiple_calls
@@ -871,8 +863,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
871863
streamInfo.lastDecodedAvFrameDuration == 0) ||
872864
(streamInfo.lastDecodedAvFramePts +
873865
streamInfo.lastDecodedAvFrameDuration <=
874-
secondsToClosestPts(startSeconds, streamInfo.timeBase)) ||
875-
!lastDecodedFrameIsPlayedAt(startSeconds),
866+
secondsToClosestPts(startSeconds, streamInfo.timeBase)),
876867
"The previous call's stop_seconds is larger than the current calls's start_seconds (roughly)");
877868

878869
setCursorPtsInSeconds(startSeconds);
@@ -883,7 +874,23 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
883874
std::vector<torch::Tensor> tensors;
884875

885876
bool reachedEOF = false;
886-
while (!lastDecodedFrameIsPlayedAt(stopSeconds) && !reachedEOF) {
877+
auto shouldStopDecoding = [&streamInfo, stopSeconds, &reachedEOF]() {
878+
if (reachedEOF) {
879+
return true;
880+
}
881+
// Return true iff stopSeconds is in [begin, end] of the last decoded frame.
882+
// We use [begin, end] and not [begin, end), which may seem counter-
883+
// intuitive, but this actually ensures that stopSeconds is an open upper
884+
// bound, i.e. a frame that starts on stopSeconds won't be part of the
885+
// output.
886+
auto pts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
887+
auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
888+
streamInfo.lastDecodedAvFrameDuration;
889+
return (streamInfo.lastDecodedAvFramePts) <= pts &&
890+
(pts <= lastDecodedAvFrameEnd);
891+
};
892+
893+
while (!shouldStopDecoding()) {
887894
try {
888895
AVFrameStream avFrameStream =
889896
decodeAVFrame([&streamInfo](AVFrame* avFrame) {

test/decoders/test_ops.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,7 @@ def get_reference_frames(start_seconds, stop_seconds):
750750
frames, get_reference_frames(start_seconds, stop_seconds)
751751
)
752752

753+
# "seeking" forward is OK
753754
start_seconds, stop_seconds = 3, 4
754755
frames = get_frames_by_pts_in_range_audio(
755756
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
@@ -771,7 +772,7 @@ def get_reference_frames(start_seconds, stop_seconds):
771772
frames, get_reference_frames(start_seconds, stop_seconds)
772773
)
773774

774-
# But starting immediately on the same frame isn't OK
775+
# but starting immediately on the same frame raises
775776
with pytest.raises(
776777
RuntimeError,
777778
match="The previous call's stop_seconds is larger than the current calls's start_seconds",
@@ -788,13 +789,13 @@ def get_reference_frames(start_seconds, stop_seconds):
788789
decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6
789790
)
790791

791-
start_seconds, stop_seconds = 0, 2
792-
frames = get_frames_by_pts_in_range_audio(
793-
decoder, start_seconds=start_seconds, stop_seconds=stop_seconds
794-
)
795-
with pytest.raises(AssertionError):
796-
torch.testing.assert_close(
797-
frames, get_reference_frames(start_seconds, stop_seconds)
792+
# and seeking backwards doesn't work either
793+
with pytest.raises(
794+
RuntimeError,
795+
match="The previous call's stop_seconds is larger than the current calls's start_seconds",
796+
):
797+
frames = get_frames_by_pts_in_range_audio(
798+
decoder, start_seconds=0, stop_seconds=2
798799
)
799800

800801

0 commit comments

Comments
 (0)