@@ -337,12 +337,12 @@ void VideoDecoder::createFilterGraph(
337337 StreamInfo& streamInfo,
338338 int expectedOutputHeight,
339339 int expectedOutputWidth) {
340- FilterState& filterState = streamInfo.filterState ;
341- filterState .filterGraph .reset (avfilter_graph_alloc ());
342- TORCH_CHECK (filterState .filterGraph .get () != nullptr );
340+ FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext ;
341+ filterGraphContext .filterGraph .reset (avfilter_graph_alloc ());
342+ TORCH_CHECK (filterGraphContext .filterGraph .get () != nullptr );
343343
344344 if (streamInfo.videoStreamOptions .ffmpegThreadCount .has_value ()) {
345- filterState .filterGraph ->nb_threads =
345+ filterGraphContext .filterGraph ->nb_threads =
346346 streamInfo.videoStreamOptions .ffmpegThreadCount .value ();
347347 }
348348
@@ -360,25 +360,25 @@ void VideoDecoder::createFilterGraph(
360360 << codecContext->sample_aspect_ratio .den ;
361361
362362 int ffmpegStatus = avfilter_graph_create_filter (
363- &filterState .sourceContext ,
363+ &filterGraphContext .sourceContext ,
364364 buffersrc,
365365 " in" ,
366366 filterArgs.str ().c_str (),
367367 nullptr ,
368- filterState .filterGraph .get ());
368+ filterGraphContext .filterGraph .get ());
369369 if (ffmpegStatus < 0 ) {
370370 throw std::runtime_error (
371371 std::string (" Failed to create filter graph: " ) + filterArgs.str () +
372372 " : " + getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
373373 }
374374
375375 ffmpegStatus = avfilter_graph_create_filter (
376- &filterState .sinkContext ,
376+ &filterGraphContext .sinkContext ,
377377 buffersink,
378378 " out" ,
379379 nullptr ,
380380 nullptr ,
381- filterState .filterGraph .get ());
381+ filterGraphContext .filterGraph .get ());
382382 if (ffmpegStatus < 0 ) {
383383 throw std::runtime_error (
384384 " Failed to create filter graph: " +
@@ -388,7 +388,7 @@ void VideoDecoder::createFilterGraph(
388388 enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
389389
390390 ffmpegStatus = av_opt_set_int_list (
391- filterState .sinkContext ,
391+ filterGraphContext .sinkContext ,
392392 " pix_fmts" ,
393393 pix_fmts,
394394 AV_PIX_FMT_NONE,
@@ -403,11 +403,11 @@ void VideoDecoder::createFilterGraph(
403403 UniqueAVFilterInOut inputs (avfilter_inout_alloc ());
404404
405405 outputs->name = av_strdup (" in" );
406- outputs->filter_ctx = filterState .sourceContext ;
406+ outputs->filter_ctx = filterGraphContext .sourceContext ;
407407 outputs->pad_idx = 0 ;
408408 outputs->next = nullptr ;
409409 inputs->name = av_strdup (" out" );
410- inputs->filter_ctx = filterState .sinkContext ;
410+ inputs->filter_ctx = filterGraphContext .sinkContext ;
411411 inputs->pad_idx = 0 ;
412412 inputs->next = nullptr ;
413413
@@ -418,7 +418,7 @@ void VideoDecoder::createFilterGraph(
418418 AVFilterInOut* outputsTmp = outputs.release ();
419419 AVFilterInOut* inputsTmp = inputs.release ();
420420 ffmpegStatus = avfilter_graph_parse_ptr (
421- filterState .filterGraph .get (),
421+ filterGraphContext .filterGraph .get (),
422422 description.str ().c_str (),
423423 &inputsTmp,
424424 &outputsTmp,
@@ -431,7 +431,8 @@ void VideoDecoder::createFilterGraph(
431431 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
432432 }
433433
434- ffmpegStatus = avfilter_graph_config (filterState.filterGraph .get (), nullptr );
434+ ffmpegStatus =
435+ avfilter_graph_config (filterGraphContext.filterGraph .get (), nullptr );
435436 if (ffmpegStatus < 0 ) {
436437 throw std::runtime_error (
437438 " Failed to configure filter graph: " +
@@ -803,16 +804,20 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
803804 }
804805}
805806
806- VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction (
807+ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
807808 std::function<bool (int , AVFrame*)> filterFunction) {
808809 if (activeStreamIndices_.size () == 0 ) {
809810 throw std::runtime_error (" No active streams configured." );
810811 }
812+
811813 resetDecodeStats ();
814+
815+ // Seek if needed.
812816 if (desiredPtsSeconds_.has_value ()) {
813817 maybeSeekToBeforeDesiredPts ();
814818 desiredPtsSeconds_ = std::nullopt ;
815819 }
820+
816821 // Need to get the next frame or error from PopFrame.
817822 UniqueAVFrame avFrame (av_frame_alloc ());
818823 AutoAVPacket autoAVPacket;
@@ -822,42 +827,58 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
822827 while (true ) {
823828 frameStreamIndex = -1 ;
824829 bool gotPermanentErrorOnAnyActiveStream = false ;
830+
831+ // Get a frame on an active stream. Note that we don't know ahead of time
832+ // which streams have frames to receive, so we linearly try the active
833+ // streams.
825834 for (int streamIndex : activeStreamIndices_) {
826835 StreamInfo& streamInfo = streamInfos_[streamIndex];
827836 ffmpegStatus =
828837 avcodec_receive_frame (streamInfo.codecContext .get (), avFrame.get ());
829- bool gotNonRetriableError =
830- ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN);
831- if (gotNonRetriableError) {
838+
839+ if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR (EAGAIN)) {
832840 gotPermanentErrorOnAnyActiveStream = true ;
833841 break ;
834842 }
843+
835844 if (ffmpegStatus == AVSUCCESS) {
836845 frameStreamIndex = streamIndex;
837846 break ;
838847 }
839848 }
849+
840850 if (gotPermanentErrorOnAnyActiveStream) {
841851 break ;
842852 }
853+
843854 decodeStats_.numFramesReceivedByDecoder ++;
844- bool gotNeededFrame = ffmpegStatus == AVSUCCESS &&
845- filterFunction (frameStreamIndex, avFrame.get ());
846- if (gotNeededFrame) {
855+
856+ // Is this the kind of frame we're looking for?
857+ if (ffmpegStatus == AVSUCCESS &&
858+ filterFunction (frameStreamIndex, avFrame.get ())) {
859+ // Yes, this is the frame we'll return; break out of the decoding loop.
847860 break ;
848861 } else if (ffmpegStatus == AVSUCCESS) {
849- // No need to send more packets here as the decoder may have frames in
850- // its buffer.
862+ // No, but we received a valid frame - just not the kind we're looking
863+ // for. The logic below will read packets and send them to the decoder.
864+ // But since we did just receive a frame, we should skip reading more
865+ // packets and sending them to the decoder and just try to receive more
866+ // frames from the decoder.
851867 continue ;
852868 }
869+
853870 if (reachedEOF) {
854871 // We don't have any more packets to send to the decoder. So keep on
855872 // pulling frames from its internal buffers.
856873 continue ;
857874 }
875+
876+ // We still haven't found the frame we're looking for. So let's read more
877+ // packets and send them to the decoder.
858878 ReferenceAVPacket packet (autoAVPacket);
859879 ffmpegStatus = av_read_frame (formatContext_.get (), packet.get ());
860880 decodeStats_.numPacketsRead ++;
881+
861882 if (ffmpegStatus == AVERROR_EOF) {
862883 // End of file reached. We must drain all codecs by sending a nullptr
863884 // packet.
@@ -872,27 +893,38 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
872893 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
873894 }
874895 }
896+
897+ // We've reached the end of file so we can't read any more packets from
898+ // it, but the decoder may still have frames to read in its buffer.
899+ // Continue iterating to try reading frames.
875900 reachedEOF = true ;
876901 continue ;
877902 }
903+
878904 if (ffmpegStatus < AVSUCCESS) {
879905 throw std::runtime_error (
880906 " Could not read frame from input file: " +
881907 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
882908 }
909+
883910 if (activeStreamIndices_.count (packet->stream_index ) == 0 ) {
884911 // This packet is not for any of the active streams.
885912 continue ;
886913 }
914+
915+ // We got a valid packet. Send it to the decoder, and we'll receive it in
916+ // the next iteration.
887917 ffmpegStatus = avcodec_send_packet (
888918 streamInfos_[packet->stream_index ].codecContext .get (), packet.get ());
889919 if (ffmpegStatus < AVSUCCESS) {
890920 throw std::runtime_error (
891921 " Could not push packet to decoder: " +
892922 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
893923 }
924+
894925 decodeStats_.numPacketsSentToDecoder ++;
895926 }
927+
896928 if (ffmpegStatus < AVSUCCESS) {
897929 if (reachedEOF || ffmpegStatus == AVERROR_EOF) {
898930 throw VideoDecoder::EndOfFileException (
@@ -903,6 +935,7 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
903935 " Could not receive frame from decoder: " +
904936 getFFMPEGErrorStringFromErrorCode (ffmpegStatus));
905937 }
938+
906939 // Note that we don't flush the decoder when we reach EOF (even though that's
907940 // mentioned in https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html).
908941 // This is because we may have packets internally in the decoder that we
@@ -912,10 +945,8 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
912945 StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
913946 activeStreamInfo.currentPts = avFrame->pts ;
914947 activeStreamInfo.currentDuration = getDuration (avFrame);
915- AVFrameStream avFrameStream;
916- avFrameStream.streamIndex = frameStreamIndex;
917- avFrameStream.avFrame = std::move (avFrame);
918- return avFrameStream;
948+
949+ return AVFrameStream (std::move (avFrame), frameStreamIndex);
919950}
920951
921952VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput (
@@ -1027,7 +1058,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
10271058 } else if (
10281059 streamInfo.colorConversionLibrary ==
10291060 ColorConversionLibrary::FILTERGRAPH) {
1030- if (!streamInfo.filterState .filterGraph ||
1061+ if (!streamInfo.filterGraphContext .filterGraph ||
10311062 streamInfo.prevFrameContext != frameContext) {
10321063 createFilterGraph (streamInfo, expectedOutputHeight, expectedOutputWidth);
10331064 streamInfo.prevFrameContext = frameContext;
@@ -1079,8 +1110,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
10791110 }
10801111
10811112 setCursorPtsInSeconds (seconds);
1082- AVFrameStream avFrameStream = getAVFrameUsingFilterFunction (
1083- [seconds, this ](int frameStreamIndex, AVFrame* avFrame) {
1113+ AVFrameStream avFrameStream =
1114+ decodeAVFrame ( [seconds, this ](int frameStreamIndex, AVFrame* avFrame) {
10841115 StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
10851116 double frameStartTime = ptsToSeconds (avFrame->pts , streamInfo.timeBase );
10861117 double frameEndTime = ptsToSeconds (
@@ -1480,8 +1511,8 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
14801511
14811512VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal (
14821513 std::optional<torch::Tensor> preAllocatedOutputTensor) {
1483- AVFrameStream avFrameStream = getAVFrameUsingFilterFunction (
1484- [this ](int frameStreamIndex, AVFrame* avFrame) {
1514+ AVFrameStream avFrameStream =
1515+ decodeAVFrame ( [this ](int frameStreamIndex, AVFrame* avFrame) {
14851516 StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
14861517 return avFrame->pts >= activeStreamInfo.discardFramesBeforePts ;
14871518 });
@@ -1585,16 +1616,17 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
15851616torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph (
15861617 int streamIndex,
15871618 const AVFrame* avFrame) {
1588- FilterState& filterState = streamInfos_[streamIndex].filterState ;
1619+ FilterGraphContext& filterGraphContext =
1620+ streamInfos_[streamIndex].filterGraphContext ;
15891621 int ffmpegStatus =
1590- av_buffersrc_write_frame (filterState .sourceContext , avFrame);
1622+ av_buffersrc_write_frame (filterGraphContext .sourceContext , avFrame);
15911623 if (ffmpegStatus < AVSUCCESS) {
15921624 throw std::runtime_error (" Failed to add frame to buffer source context" );
15931625 }
15941626
15951627 UniqueAVFrame filteredAVFrame (av_frame_alloc ());
1596- ffmpegStatus =
1597- av_buffersink_get_frame (filterState .sinkContext , filteredAVFrame.get ());
1628+ ffmpegStatus = av_buffersink_get_frame (
1629+ filterGraphContext .sinkContext , filteredAVFrame.get ());
15981630 TORCH_CHECK_EQ (filteredAVFrame->format , AV_PIX_FMT_RGB24);
15991631
16001632 auto frameDims = getHeightAndWidthFromResizedAVFrame (*filteredAVFrame.get ());
0 commit comments