@@ -876,24 +876,9 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
876876 // sample rate, so in theory we know the number of output samples.
877877 std::vector<torch::Tensor> tensors;
878878
879- bool reachedEOF = false ;
880- auto shouldStopDecoding = [&streamInfo, stopSeconds, &reachedEOF]() {
881- if (reachedEOF) {
882- return true ;
883- }
884- // Return true iff stopSeconds is in [begin, end] of the last decoded frame.
885- // We use [begin, end] and not [begin, end), which may seem counter
886- // intuitive, but this actually ensures that stopSeconds is an open upper
887- // bound, i.e. a frame that starts on stopSeconds won't be part of the
888- // output.
889- auto pts = secondsToClosestPts (stopSeconds, streamInfo.timeBase );
890- auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
891- streamInfo.lastDecodedAvFrameDuration ;
892- return (streamInfo.lastDecodedAvFramePts ) <= pts &&
893- (pts <= lastDecodedAvFrameEnd);
894- };
895-
896- while (!shouldStopDecoding ()) {
879+ auto stopPts = secondsToClosestPts (stopSeconds, streamInfo.timeBase );
880+ auto shouldStopDecoding = false ;
881+ while (!shouldStopDecoding) {
897882 try {
898883 AVFrameStream avFrameStream =
899884 decodeAVFrame ([&streamInfo](AVFrame* avFrame) {
@@ -904,8 +889,17 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio(
904889 auto frameOutput = convertAVFrameToFrameOutput (avFrameStream);
905890 tensors.push_back (frameOutput.data );
906891 } catch (const EndOfFileException& e) {
907- reachedEOF = true ;
892+ shouldStopDecoding = true ;
908893 }
894+
895+ // If stopSeconds is in [begin, end] of the last decoded frame, we should
896+ // stop decoding more frames. Note that if we were to use [begin, end),
897+ // which may seem more natural, then we would decode the frame starting at
898+ // stopSeconds, which isn't what we want!
899+ auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
900+ streamInfo.lastDecodedAvFrameDuration ;
901+ shouldStopDecoding |= (streamInfo.lastDecodedAvFramePts ) <= stopPts &&
902+ (stopPts <= lastDecodedAvFrameEnd);
909903 }
910904 return torch::cat (tensors, 1 );
911905}
0 commit comments