@@ -855,14 +855,6 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
855855
856856 StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
857857
858- auto lastDecodedFrameIsPlayedAt = [&streamInfo](double seconds) {
859- auto pts = secondsToClosestPts (seconds, streamInfo.timeBase );
860- return (
861- streamInfo.lastDecodedAvFramePts <= pts and
862- pts <= streamInfo.lastDecodedAvFramePts +
863- streamInfo.lastDecodedAvFrameDuration );
864- };
865-
866858 // TODO-AUDIO This essentially enforce that we don't need to seek (backwards).
867859 // We should remove it and seek back to the stream's beginning when needed.
868860 // See test_multiple_calls
@@ -871,8 +863,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
871863 streamInfo.lastDecodedAvFrameDuration == 0 ) ||
872864 (streamInfo.lastDecodedAvFramePts +
873865 streamInfo.lastDecodedAvFrameDuration <=
874- secondsToClosestPts (startSeconds, streamInfo.timeBase )) ||
875- !lastDecodedFrameIsPlayedAt (startSeconds),
866+ secondsToClosestPts (startSeconds, streamInfo.timeBase )),
876867 " The previous call's stop_seconds is larger than the current calls's start_seconds (roughly)" );
877868
878869 setCursorPtsInSeconds (startSeconds);
@@ -883,7 +874,23 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
883874 std::vector<torch::Tensor> tensors;
884875
885876 bool reachedEOF = false ;
886- while (!lastDecodedFrameIsPlayedAt (stopSeconds) && !reachedEOF) {
877+ auto shouldStopDecoding = [&streamInfo, stopSeconds, &reachedEOF]() {
878+ if (reachedEOF) {
879+ return true ;
880+ }
881+ // Return true iff stopSeconds is in [begin, end] of the last decoded frame.
882+ // We use [begin, end] and not [begin, end), which may seem counter
883+ // intuitive, but this actually ensures that stopSeconds is an open upper
884+ // bound, i.e. a frame that starts on stopSeconds won't be part of the
885+ // output.
886+ auto pts = secondsToClosestPts (stopSeconds, streamInfo.timeBase );
887+ auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
888+ streamInfo.lastDecodedAvFrameDuration ;
889+ return (streamInfo.lastDecodedAvFramePts ) <= pts &&
890+ (pts <= lastDecodedAvFrameEnd);
891+ };
892+
893+ while (!shouldStopDecoding ()) {
887894 try {
888895 AVFrameStream avFrameStream =
889896 decodeAVFrame ([&streamInfo](AVFrame* avFrame) {
0 commit comments