Skip to content

Commit cfadb8c

Browse files
committed
"Remove multi-stream related code"
1 parent 746401c commit cfadb8c

File tree

2 files changed

+47
-96
lines changed

2 files changed

+47
-96
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 45 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -449,11 +449,7 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
449449
void VideoDecoder::addVideoStreamDecoder(
450450
int preferredStreamIndex,
451451
const VideoStreamOptions& videoStreamOptions) {
452-
if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
453-
throw std::invalid_argument(
454-
"Stream with index " + std::to_string(preferredStreamIndex) +
455-
" is already active.");
456-
}
452+
TORCH_CHECK(activeStreamIndex_ == -1, "Can only add one single stream.");
457453
TORCH_CHECK(formatContext_.get() != nullptr);
458454

459455
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -520,7 +516,7 @@ void VideoDecoder::addVideoStreamDecoder(
520516
}
521517

522518
codecContext->time_base = streamInfo.stream->time_base;
523-
activeStreamIndices_.insert(streamIndex);
519+
activeStreamIndex_ = streamIndex;
524520
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
525521
streamInfo.videoStreamOptions = videoStreamOptions;
526522

@@ -740,53 +736,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
740736
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
741737
// the comment of canWeAvoidSeeking() for details.
742738
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
743-
if (activeStreamIndices_.size() == 0) {
739+
if (activeStreamIndex_ == -1) {
744740
return;
745741
}
746-
for (int streamIndex : activeStreamIndices_) {
747-
StreamInfo& streamInfo = streamInfos_[streamIndex];
748-
// clang-format off: clang format clashes
749-
streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
750-
// clang-format on
751-
}
742+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
743+
streamInfo.discardFramesBeforePts =
744+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
752745

753746
decodeStats_.numSeeksAttempted++;
754-
// See comment for canWeAvoidSeeking() for details on why this optimization
755-
// works.
756-
bool mustSeek = false;
757-
for (int streamIndex : activeStreamIndices_) {
758-
StreamInfo& streamInfo = streamInfos_[streamIndex];
759-
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
760-
if (!canWeAvoidSeekingForStream(
761-
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
762-
mustSeek = true;
763-
break;
764-
}
765-
}
766-
if (!mustSeek) {
747+
748+
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
749+
if (canWeAvoidSeekingForStream(
750+
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
767751
decodeStats_.numSeeksSkipped++;
768752
return;
769753
}
770-
int firstActiveStreamIndex = *activeStreamIndices_.begin();
771-
const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
772754
int64_t desiredPts =
773-
secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
755+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
774756

775757
// For some encodings like H265, FFMPEG sometimes seeks past the point we
776758
// set as the max_ts. So we use our own index to give it the exact pts of
777759
// the key frame that we want to seek to.
778760
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
779761
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
780-
if (!firstStreamInfo.keyFrames.empty()) {
762+
if (!streamInfo.keyFrames.empty()) {
781763
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
782-
firstStreamInfo.keyFrames, desiredPts);
764+
streamInfo.keyFrames, desiredPts);
783765
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
784-
desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
766+
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
785767
}
786768

787769
int ffmepgStatus = avformat_seek_file(
788770
formatContext_.get(),
789-
firstStreamInfo.streamIndex,
771+
streamInfo.streamIndex,
790772
INT64_MIN,
791773
desiredPts,
792774
desiredPts,
@@ -797,15 +779,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
797779
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
798780
}
799781
decodeStats_.numFlushes++;
800-
for (int streamIndex : activeStreamIndices_) {
801-
StreamInfo& streamInfo = streamInfos_[streamIndex];
802-
avcodec_flush_buffers(streamInfo.codecContext.get());
803-
}
782+
avcodec_flush_buffers(streamInfo.codecContext.get());
804783
}
805784

806785
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
807-
std::function<bool(int, AVFrame*)> filterFunction) {
808-
if (activeStreamIndices_.size() == 0) {
786+
std::function<bool(AVFrame*)> filterFunction) {
787+
if (activeStreamIndex_ == -1) {
809788
throw std::runtime_error("No active streams configured.");
810789
}
811790

@@ -817,44 +796,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
817796
desiredPtsSeconds_ = std::nullopt;
818797
}
819798

799+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
800+
820801
// Need to get the next frame or error from PopFrame.
821802
UniqueAVFrame avFrame(av_frame_alloc());
822803
AutoAVPacket autoAVPacket;
823804
int ffmpegStatus = AVSUCCESS;
824805
bool reachedEOF = false;
825-
int frameStreamIndex = -1;
826806
while (true) {
827-
frameStreamIndex = -1;
828-
bool gotPermanentErrorOnAnyActiveStream = false;
829-
830-
// Get a frame on an active stream. Note that we don't know ahead of time
831-
// which streams have frames to receive, so we linearly try the active
832-
// streams.
833-
for (int streamIndex : activeStreamIndices_) {
834-
StreamInfo& streamInfo = streamInfos_[streamIndex];
835-
ffmpegStatus =
836-
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
837-
838-
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
839-
gotPermanentErrorOnAnyActiveStream = true;
840-
break;
841-
}
807+
ffmpegStatus =
808+
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
842809

843-
if (ffmpegStatus == AVSUCCESS) {
844-
frameStreamIndex = streamIndex;
845-
break;
846-
}
847-
}
848-
849-
if (gotPermanentErrorOnAnyActiveStream) {
810+
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
811+
// Non-retriable error
850812
break;
851813
}
852814

853815
decodeStats_.numFramesReceivedByDecoder++;
854-
855816
// Is this the kind of frame we're looking for?
856-
if (ffmpegStatus == AVSUCCESS &&
857-
filterFunction(frameStreamIndex, avFrame.get())) {
817+
if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
858818
// Yes, this is the frame we'll return; break out of the decoding loop.
859819
break;
860820
} else if (ffmpegStatus == AVSUCCESS) {
@@ -879,18 +839,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
879839
decodeStats_.numPacketsRead++;
880840

881841
if (ffmpegStatus == AVERROR_EOF) {
882-
// End of file reached. We must drain all codecs by sending a nullptr
842+
// End of file reached. We must drain the codec by sending a nullptr
883843
// packet.
884-
for (int streamIndex : activeStreamIndices_) {
885-
StreamInfo& streamInfo = streamInfos_[streamIndex];
886-
ffmpegStatus = avcodec_send_packet(
887-
streamInfo.codecContext.get(),
888-
/*avpkt=*/nullptr);
889-
if (ffmpegStatus < AVSUCCESS) {
890-
throw std::runtime_error(
891-
"Could not flush decoder: " +
892-
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
893-
}
844+
ffmpegStatus = avcodec_send_packet(
845+
streamInfo.codecContext.get(),
846+
/*avpkt=*/nullptr);
847+
if (ffmpegStatus < AVSUCCESS) {
848+
throw std::runtime_error(
849+
"Could not flush decoder: " +
850+
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
894851
}
895852

896853
// We've reached the end of file so we can't read any more packets from
@@ -906,15 +863,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
906863
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
907864
}
908865

909-
if (activeStreamIndices_.count(packet->stream_index) == 0) {
910-
// This packet is not for any of the active streams.
866+
if (packet->stream_index != activeStreamIndex_) {
911867
continue;
912868
}
913869

914870
// We got a valid packet. Send it to the decoder, and we'll receive it in
915871
// the next iteration.
916-
ffmpegStatus = avcodec_send_packet(
917-
streamInfos_[packet->stream_index].codecContext.get(), packet.get());
872+
ffmpegStatus =
873+
avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
918874
if (ffmpegStatus < AVSUCCESS) {
919875
throw std::runtime_error(
920876
"Could not push packet to decoder: " +
@@ -941,11 +897,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
941897
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
942898
// av_receive_frame() or the user will have seeked to a different location in
943899
// the file and that will flush the decoder.
944-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
945-
activeStreamInfo.currentPts = avFrame->pts;
946-
activeStreamInfo.currentDuration = getDuration(avFrame);
900+
streamInfo.currentPts = avFrame->pts;
901+
streamInfo.currentDuration = getDuration(avFrame);
947902

948-
return AVFrameStream(std::move(avFrame), frameStreamIndex);
903+
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
949904
}
950905

951906
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1110,8 +1065,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
11101065

11111066
setCursorPtsInSeconds(seconds);
11121067
AVFrameStream avFrameStream =
1113-
decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
1114-
StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
1068+
decodeAVFrame([seconds, this](AVFrame* avFrame) {
1069+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
11151070
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
11161071
double frameEndTime = ptsToSeconds(
11171072
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1510,11 +1465,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
15101465

15111466
VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
15121467
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1513-
AVFrameStream avFrameStream =
1514-
decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
1515-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1516-
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1517-
});
1468+
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
1469+
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1470+
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1471+
});
15181472
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
15191473
}
15201474

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,7 @@ class VideoDecoder {
404404
const enum AVColorSpace colorspace);
405405

406406
void maybeSeekToBeforeDesiredPts();
407-
AVFrameStream decodeAVFrame(
408-
std::function<bool(int, AVFrame*)> filterFunction);
407+
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
409408
// Once we create a decoder can update the metadata with the codec context.
410409
// For example, for video streams, we can add the height and width of the
411410
// decoded stream.
@@ -435,9 +434,7 @@ class VideoDecoder {
435434
ContainerMetadata containerMetadata_;
436435
UniqueAVFormatContext formatContext_;
437436
std::map<int, StreamInfo> streamInfos_;
438-
// Stores the stream indices of the active streams, i.e. the streams we are
439-
// decoding and returning to the user.
440-
std::set<int> activeStreamIndices_;
437+
int activeStreamIndex_ = -1;
441438
// Set when the user wants to seek and stores the desired pts that the user
442439
// wants to seek to.
443440
std::optional<double> desiredPtsSeconds_;

0 commit comments

Comments
 (0)