Skip to content

Commit c8e70cd

Browse files
author
pytorchbot
committed
2025-01-28 nightly release (3253c8f)
1 parent 0221809 commit c8e70cd

File tree

2 files changed

+266
-195
lines changed

2 files changed

+266
-195
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,12 @@ void VideoDecoder::createFilterGraph(
337337
StreamInfo& streamInfo,
338338
int expectedOutputHeight,
339339
int expectedOutputWidth) {
340-
FilterState& filterState = streamInfo.filterState;
341-
filterState.filterGraph.reset(avfilter_graph_alloc());
342-
TORCH_CHECK(filterState.filterGraph.get() != nullptr);
340+
FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext;
341+
filterGraphContext.filterGraph.reset(avfilter_graph_alloc());
342+
TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr);
343343

344344
if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) {
345-
filterState.filterGraph->nb_threads =
345+
filterGraphContext.filterGraph->nb_threads =
346346
streamInfo.videoStreamOptions.ffmpegThreadCount.value();
347347
}
348348

@@ -360,25 +360,25 @@ void VideoDecoder::createFilterGraph(
360360
<< codecContext->sample_aspect_ratio.den;
361361

362362
int ffmpegStatus = avfilter_graph_create_filter(
363-
&filterState.sourceContext,
363+
&filterGraphContext.sourceContext,
364364
buffersrc,
365365
"in",
366366
filterArgs.str().c_str(),
367367
nullptr,
368-
filterState.filterGraph.get());
368+
filterGraphContext.filterGraph.get());
369369
if (ffmpegStatus < 0) {
370370
throw std::runtime_error(
371371
std::string("Failed to create filter graph: ") + filterArgs.str() +
372372
": " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
373373
}
374374

375375
ffmpegStatus = avfilter_graph_create_filter(
376-
&filterState.sinkContext,
376+
&filterGraphContext.sinkContext,
377377
buffersink,
378378
"out",
379379
nullptr,
380380
nullptr,
381-
filterState.filterGraph.get());
381+
filterGraphContext.filterGraph.get());
382382
if (ffmpegStatus < 0) {
383383
throw std::runtime_error(
384384
"Failed to create filter graph: " +
@@ -388,7 +388,7 @@ void VideoDecoder::createFilterGraph(
388388
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
389389

390390
ffmpegStatus = av_opt_set_int_list(
391-
filterState.sinkContext,
391+
filterGraphContext.sinkContext,
392392
"pix_fmts",
393393
pix_fmts,
394394
AV_PIX_FMT_NONE,
@@ -403,11 +403,11 @@ void VideoDecoder::createFilterGraph(
403403
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
404404

405405
outputs->name = av_strdup("in");
406-
outputs->filter_ctx = filterState.sourceContext;
406+
outputs->filter_ctx = filterGraphContext.sourceContext;
407407
outputs->pad_idx = 0;
408408
outputs->next = nullptr;
409409
inputs->name = av_strdup("out");
410-
inputs->filter_ctx = filterState.sinkContext;
410+
inputs->filter_ctx = filterGraphContext.sinkContext;
411411
inputs->pad_idx = 0;
412412
inputs->next = nullptr;
413413

@@ -418,7 +418,7 @@ void VideoDecoder::createFilterGraph(
418418
AVFilterInOut* outputsTmp = outputs.release();
419419
AVFilterInOut* inputsTmp = inputs.release();
420420
ffmpegStatus = avfilter_graph_parse_ptr(
421-
filterState.filterGraph.get(),
421+
filterGraphContext.filterGraph.get(),
422422
description.str().c_str(),
423423
&inputsTmp,
424424
&outputsTmp,
@@ -431,7 +431,8 @@ void VideoDecoder::createFilterGraph(
431431
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
432432
}
433433

434-
ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr);
434+
ffmpegStatus =
435+
avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr);
435436
if (ffmpegStatus < 0) {
436437
throw std::runtime_error(
437438
"Failed to configure filter graph: " +
@@ -803,16 +804,20 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
803804
}
804805
}
805806

806-
VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
807+
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
807808
std::function<bool(int, AVFrame*)> filterFunction) {
808809
if (activeStreamIndices_.size() == 0) {
809810
throw std::runtime_error("No active streams configured.");
810811
}
812+
811813
resetDecodeStats();
814+
815+
// Seek if needed.
812816
if (desiredPtsSeconds_.has_value()) {
813817
maybeSeekToBeforeDesiredPts();
814818
desiredPtsSeconds_ = std::nullopt;
815819
}
820+
816821
// Need to get the next frame or error from PopFrame.
817822
UniqueAVFrame avFrame(av_frame_alloc());
818823
AutoAVPacket autoAVPacket;
@@ -822,42 +827,58 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
822827
while (true) {
823828
frameStreamIndex = -1;
824829
bool gotPermanentErrorOnAnyActiveStream = false;
830+
831+
// Get a frame on an active stream. Note that we don't know ahead of time
832+
// which streams have frames to receive, so we linearly try the active
833+
// streams.
825834
for (int streamIndex : activeStreamIndices_) {
826835
StreamInfo& streamInfo = streamInfos_[streamIndex];
827836
ffmpegStatus =
828837
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
829-
bool gotNonRetriableError =
830-
ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN);
831-
if (gotNonRetriableError) {
838+
839+
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
832840
gotPermanentErrorOnAnyActiveStream = true;
833841
break;
834842
}
843+
835844
if (ffmpegStatus == AVSUCCESS) {
836845
frameStreamIndex = streamIndex;
837846
break;
838847
}
839848
}
849+
840850
if (gotPermanentErrorOnAnyActiveStream) {
841851
break;
842852
}
853+
843854
decodeStats_.numFramesReceivedByDecoder++;
844-
bool gotNeededFrame = ffmpegStatus == AVSUCCESS &&
845-
filterFunction(frameStreamIndex, avFrame.get());
846-
if (gotNeededFrame) {
855+
856+
// Is this the kind of frame we're looking for?
857+
if (ffmpegStatus == AVSUCCESS &&
858+
filterFunction(frameStreamIndex, avFrame.get())) {
859+
// Yes, this is the frame we'll return; break out of the decoding loop.
847860
break;
848861
} else if (ffmpegStatus == AVSUCCESS) {
849-
// No need to send more packets here as the decoder may have frames in
850-
// its buffer.
862+
// No, but we received a valid frame - just not the kind we're looking
863+
// for. The logic below will read packets and send them to the decoder.
864+
// But since we did just receive a frame, we should skip reading more
865+
// packets and sending them to the decoder and just try to receive more
866+
// frames from the decoder.
851867
continue;
852868
}
869+
853870
if (reachedEOF) {
854871
// We don't have any more packets to send to the decoder. So keep on
855872
// pulling frames from its internal buffers.
856873
continue;
857874
}
875+
876+
// We still haven't found the frame we're looking for. So let's read more
877+
// packets and send them to the decoder.
858878
ReferenceAVPacket packet(autoAVPacket);
859879
ffmpegStatus = av_read_frame(formatContext_.get(), packet.get());
860880
decodeStats_.numPacketsRead++;
881+
861882
if (ffmpegStatus == AVERROR_EOF) {
862883
// End of file reached. We must drain all codecs by sending a nullptr
863884
// packet.
@@ -872,27 +893,38 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
872893
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
873894
}
874895
}
896+
897+
// We've reached the end of file so we can't read any more packets from
898+
// it, but the decoder may still have frames to read in its buffer.
899+
// Continue iterating to try reading frames.
875900
reachedEOF = true;
876901
continue;
877902
}
903+
878904
if (ffmpegStatus < AVSUCCESS) {
879905
throw std::runtime_error(
880906
"Could not read frame from input file: " +
881907
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
882908
}
909+
883910
if (activeStreamIndices_.count(packet->stream_index) == 0) {
884911
// This packet is not for any of the active streams.
885912
continue;
886913
}
914+
915+
// We got a valid packet. Send it to the decoder, and we'll receive it in
916+
// the next iteration.
887917
ffmpegStatus = avcodec_send_packet(
888918
streamInfos_[packet->stream_index].codecContext.get(), packet.get());
889919
if (ffmpegStatus < AVSUCCESS) {
890920
throw std::runtime_error(
891921
"Could not push packet to decoder: " +
892922
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
893923
}
924+
894925
decodeStats_.numPacketsSentToDecoder++;
895926
}
927+
896928
if (ffmpegStatus < AVSUCCESS) {
897929
if (reachedEOF || ffmpegStatus == AVERROR_EOF) {
898930
throw VideoDecoder::EndOfFileException(
@@ -903,6 +935,7 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
903935
"Could not receive frame from decoder: " +
904936
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
905937
}
938+
906939
// Note that we don't flush the decoder when we reach EOF (even though that's
907940
// mentioned in https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html).
908941
// This is because we may have packets internally in the decoder that we
@@ -912,10 +945,8 @@ VideoDecoder::AVFrameStream VideoDecoder::getAVFrameUsingFilterFunction(
912945
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
913946
activeStreamInfo.currentPts = avFrame->pts;
914947
activeStreamInfo.currentDuration = getDuration(avFrame);
915-
AVFrameStream avFrameStream;
916-
avFrameStream.streamIndex = frameStreamIndex;
917-
avFrameStream.avFrame = std::move(avFrame);
918-
return avFrameStream;
948+
949+
return AVFrameStream(std::move(avFrame), frameStreamIndex);
919950
}
920951

921952
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1027,7 +1058,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
10271058
} else if (
10281059
streamInfo.colorConversionLibrary ==
10291060
ColorConversionLibrary::FILTERGRAPH) {
1030-
if (!streamInfo.filterState.filterGraph ||
1061+
if (!streamInfo.filterGraphContext.filterGraph ||
10311062
streamInfo.prevFrameContext != frameContext) {
10321063
createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
10331064
streamInfo.prevFrameContext = frameContext;
@@ -1079,8 +1110,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
10791110
}
10801111

10811112
setCursorPtsInSeconds(seconds);
1082-
AVFrameStream avFrameStream = getAVFrameUsingFilterFunction(
1083-
[seconds, this](int frameStreamIndex, AVFrame* avFrame) {
1113+
AVFrameStream avFrameStream =
1114+
decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
10841115
StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
10851116
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
10861117
double frameEndTime = ptsToSeconds(
@@ -1480,8 +1511,8 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
14801511

14811512
VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
14821513
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1483-
AVFrameStream avFrameStream = getAVFrameUsingFilterFunction(
1484-
[this](int frameStreamIndex, AVFrame* avFrame) {
1514+
AVFrameStream avFrameStream =
1515+
decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
14851516
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
14861517
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
14871518
});
@@ -1585,16 +1616,17 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
15851616
torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
15861617
int streamIndex,
15871618
const AVFrame* avFrame) {
1588-
FilterState& filterState = streamInfos_[streamIndex].filterState;
1619+
FilterGraphContext& filterGraphContext =
1620+
streamInfos_[streamIndex].filterGraphContext;
15891621
int ffmpegStatus =
1590-
av_buffersrc_write_frame(filterState.sourceContext, avFrame);
1622+
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
15911623
if (ffmpegStatus < AVSUCCESS) {
15921624
throw std::runtime_error("Failed to add frame to buffer source context");
15931625
}
15941626

15951627
UniqueAVFrame filteredAVFrame(av_frame_alloc());
1596-
ffmpegStatus =
1597-
av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get());
1628+
ffmpegStatus = av_buffersink_get_frame(
1629+
filterGraphContext.sinkContext, filteredAVFrame.get());
15981630
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
15991631

16001632
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());

0 commit comments

Comments
 (0)