Skip to content

Commit ae593d1

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into remove_method
2 parents 6515963 + 28b0de0 commit ae593d1

File tree

9 files changed

+175
-134
lines changed

9 files changed

+175
-134
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 75 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -318,21 +318,6 @@ void VideoDecoder::initializeDecoder() {
318318
initialized_ = true;
319319
}
320320

321-
std::unique_ptr<VideoDecoder> VideoDecoder::createFromFilePath(
322-
const std::string& videoFilePath,
323-
SeekMode seekMode) {
324-
return std::unique_ptr<VideoDecoder>(
325-
new VideoDecoder(videoFilePath, seekMode));
326-
}
327-
328-
std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
329-
const void* buffer,
330-
size_t length,
331-
SeekMode seekMode) {
332-
return std::unique_ptr<VideoDecoder>(
333-
new VideoDecoder(buffer, length, seekMode));
334-
}
335-
336321
void VideoDecoder::createFilterGraph(
337322
StreamInfo& streamInfo,
338323
int expectedOutputHeight,
@@ -450,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
450435
void VideoDecoder::addVideoStreamDecoder(
451436
int preferredStreamIndex,
452437
const VideoStreamOptions& videoStreamOptions) {
453-
if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
454-
throw std::invalid_argument(
455-
"Stream with index " + std::to_string(preferredStreamIndex) +
456-
" is already active.");
457-
}
438+
TORCH_CHECK(
439+
activeStreamIndex_ == NO_ACTIVE_STREAM,
440+
"Can only add one single stream.");
458441
TORCH_CHECK(formatContext_.get() != nullptr);
459442

460443
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -521,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
521504
}
522505

523506
codecContext->time_base = streamInfo.stream->time_base;
524-
activeStreamIndices_.insert(streamIndex);
507+
activeStreamIndex_ = streamIndex;
525508
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
526509
streamInfo.videoStreamOptions = videoStreamOptions;
527510

@@ -553,6 +536,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
553536
return containerMetadata_;
554537
}
555538

539+
torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
540+
validateUserProvidedStreamIndex(streamIndex);
541+
validateScannedAllStreams("getKeyFrameIndices");
542+
543+
const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames;
544+
torch::Tensor keyFrameIndices =
545+
torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
546+
for (size_t i = 0; i < keyFrames.size(); ++i) {
547+
keyFrameIndices[i] = keyFrames[i].frameIndex;
548+
}
549+
550+
return keyFrameIndices;
551+
}
552+
556553
int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
557554
const std::vector<VideoDecoder::FrameInfo>& keyFrames,
558555
int64_t pts) const {
@@ -661,7 +658,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
661658
return frameInfo1.pts < frameInfo2.pts;
662659
});
663660

661+
size_t keyIndex = 0;
664662
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
663+
streamInfo.allFrames[i].frameIndex = i;
664+
665+
// For correctly encoded files, we shouldn't need to ensure that keyIndex
666+
// is less than the number of key frames. That is, the relationship
667+
// between the frames in allFrames and keyFrames should be such that
668+
// keyIndex is always a valid index into keyFrames. But we're being
669+
// defensive in case we encounter incorrectly encoded files.
670+
if (keyIndex < streamInfo.keyFrames.size() &&
671+
streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
672+
streamInfo.keyFrames[keyIndex].frameIndex = i;
673+
++keyIndex;
674+
}
675+
665676
if (i + 1 < streamInfo.allFrames.size()) {
666677
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
667678
}
@@ -735,53 +746,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
735746
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
736747
// the comment of canWeAvoidSeekingForStream() for details.
737748
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
738-
if (activeStreamIndices_.size() == 0) {
749+
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
739750
return;
740751
}
741-
for (int streamIndex : activeStreamIndices_) {
742-
StreamInfo& streamInfo = streamInfos_[streamIndex];
743-
// clang-format off: clang format clashes
744-
streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
745-
// clang-format on
746-
}
752+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
753+
streamInfo.discardFramesBeforePts =
754+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
747755

748756
decodeStats_.numSeeksAttempted++;
749-
// See comment for canWeAvoidSeeking() for details on why this optimization
750-
// works.
751-
bool mustSeek = false;
752-
for (int streamIndex : activeStreamIndices_) {
753-
StreamInfo& streamInfo = streamInfos_[streamIndex];
754-
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
755-
if (!canWeAvoidSeekingForStream(
756-
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
757-
mustSeek = true;
758-
break;
759-
}
760-
}
761-
if (!mustSeek) {
757+
758+
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
759+
if (canWeAvoidSeekingForStream(
760+
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
762761
decodeStats_.numSeeksSkipped++;
763762
return;
764763
}
765-
int firstActiveStreamIndex = *activeStreamIndices_.begin();
766-
const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
767764
int64_t desiredPts =
768-
secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
765+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
769766

770767
// For some encodings like H265, FFMPEG sometimes seeks past the point we
771768
// set as the max_ts. So we use our own index to give it the exact pts of
772769
// the key frame that we want to seek to.
773770
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
774771
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
775-
if (!firstStreamInfo.keyFrames.empty()) {
772+
if (!streamInfo.keyFrames.empty()) {
776773
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
777-
firstStreamInfo.keyFrames, desiredPts);
774+
streamInfo.keyFrames, desiredPts);
778775
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
779-
desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
776+
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
780777
}
781778

782779
int ffmepgStatus = avformat_seek_file(
783780
formatContext_.get(),
784-
firstStreamInfo.streamIndex,
781+
streamInfo.streamIndex,
785782
INT64_MIN,
786783
desiredPts,
787784
desiredPts,
@@ -792,15 +789,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
792789
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
793790
}
794791
decodeStats_.numFlushes++;
795-
for (int streamIndex : activeStreamIndices_) {
796-
StreamInfo& streamInfo = streamInfos_[streamIndex];
797-
avcodec_flush_buffers(streamInfo.codecContext.get());
798-
}
792+
avcodec_flush_buffers(streamInfo.codecContext.get());
799793
}
800794

801795
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
802-
std::function<bool(int, AVFrame*)> filterFunction) {
803-
if (activeStreamIndices_.size() == 0) {
796+
std::function<bool(AVFrame*)> filterFunction) {
797+
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
804798
throw std::runtime_error("No active streams configured.");
805799
}
806800

@@ -812,44 +806,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
812806
desiredPtsSeconds_ = std::nullopt;
813807
}
814808

809+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
810+
815811
// Need to get the next frame or error from PopFrame.
816812
UniqueAVFrame avFrame(av_frame_alloc());
817813
AutoAVPacket autoAVPacket;
818814
int ffmpegStatus = AVSUCCESS;
819815
bool reachedEOF = false;
820-
int frameStreamIndex = -1;
821816
while (true) {
822-
frameStreamIndex = -1;
823-
bool gotPermanentErrorOnAnyActiveStream = false;
824-
825-
// Get a frame on an active stream. Note that we don't know ahead of time
826-
// which streams have frames to receive, so we linearly try the active
827-
// streams.
828-
for (int streamIndex : activeStreamIndices_) {
829-
StreamInfo& streamInfo = streamInfos_[streamIndex];
830-
ffmpegStatus =
831-
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
832-
833-
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
834-
gotPermanentErrorOnAnyActiveStream = true;
835-
break;
836-
}
817+
ffmpegStatus =
818+
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
837819

838-
if (ffmpegStatus == AVSUCCESS) {
839-
frameStreamIndex = streamIndex;
840-
break;
841-
}
842-
}
843-
844-
if (gotPermanentErrorOnAnyActiveStream) {
820+
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
821+
// Non-retriable error
845822
break;
846823
}
847824

848825
decodeStats_.numFramesReceivedByDecoder++;
849-
850826
// Is this the kind of frame we're looking for?
851-
if (ffmpegStatus == AVSUCCESS &&
852-
filterFunction(frameStreamIndex, avFrame.get())) {
827+
if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
853828
// Yes, this is the frame we'll return; break out of the decoding loop.
854829
break;
855830
} else if (ffmpegStatus == AVSUCCESS) {
@@ -874,18 +849,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
874849
decodeStats_.numPacketsRead++;
875850

876851
if (ffmpegStatus == AVERROR_EOF) {
877-
// End of file reached. We must drain all codecs by sending a nullptr
852+
// End of file reached. We must drain the codec by sending a nullptr
878853
// packet.
879-
for (int streamIndex : activeStreamIndices_) {
880-
StreamInfo& streamInfo = streamInfos_[streamIndex];
881-
ffmpegStatus = avcodec_send_packet(
882-
streamInfo.codecContext.get(),
883-
/*avpkt=*/nullptr);
884-
if (ffmpegStatus < AVSUCCESS) {
885-
throw std::runtime_error(
886-
"Could not flush decoder: " +
887-
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
888-
}
854+
ffmpegStatus = avcodec_send_packet(
855+
streamInfo.codecContext.get(),
856+
/*avpkt=*/nullptr);
857+
if (ffmpegStatus < AVSUCCESS) {
858+
throw std::runtime_error(
859+
"Could not flush decoder: " +
860+
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
889861
}
890862

891863
// We've reached the end of file so we can't read any more packets from
@@ -901,15 +873,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
901873
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
902874
}
903875

904-
if (activeStreamIndices_.count(packet->stream_index) == 0) {
905-
// This packet is not for any of the active streams.
876+
if (packet->stream_index != activeStreamIndex_) {
906877
continue;
907878
}
908879

909880
// We got a valid packet. Send it to the decoder, and we'll receive it in
910881
// the next iteration.
911-
ffmpegStatus = avcodec_send_packet(
912-
streamInfos_[packet->stream_index].codecContext.get(), packet.get());
882+
ffmpegStatus =
883+
avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
913884
if (ffmpegStatus < AVSUCCESS) {
914885
throw std::runtime_error(
915886
"Could not push packet to decoder: " +
@@ -936,11 +907,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
936907
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
937908
// avcodec_receive_frame() or the user will have seeked to a different location in
938909
// the file and that will flush the decoder.
939-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
940-
activeStreamInfo.currentPts = avFrame->pts;
941-
activeStreamInfo.currentDuration = getDuration(avFrame);
910+
streamInfo.currentPts = avFrame->pts;
911+
streamInfo.currentDuration = getDuration(avFrame);
942912

943-
return AVFrameStream(std::move(avFrame), frameStreamIndex);
913+
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
944914
}
945915

946916
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1105,8 +1075,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
11051075

11061076
setCursorPtsInSeconds(seconds);
11071077
AVFrameStream avFrameStream =
1108-
decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
1109-
StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
1078+
decodeAVFrame([seconds, this](AVFrame* avFrame) {
1079+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
11101080
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
11111081
double frameEndTime = ptsToSeconds(
11121082
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1505,11 +1475,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
15051475

15061476
VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
15071477
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1508-
AVFrameStream avFrameStream =
1509-
decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
1510-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1511-
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1512-
});
1478+
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
1479+
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1480+
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1481+
});
15131482
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
15141483
}
15151484

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -29,17 +29,14 @@ class VideoDecoder {
2929

3030
enum class SeekMode { exact, approximate };
3131

32-
explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
33-
explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
34-
3532
// Creates a VideoDecoder from the video at videoFilePath.
36-
static std::unique_ptr<VideoDecoder> createFromFilePath(
33+
explicit VideoDecoder(
3734
const std::string& videoFilePath,
3835
SeekMode seekMode = SeekMode::exact);
3936

4037
// Creates a VideoDecoder from a given buffer. Note that the buffer is not
4138
// owned by the VideoDecoder.
42-
static std::unique_ptr<VideoDecoder> createFromBuffer(
39+
explicit VideoDecoder(
4340
const void* buffer,
4441
size_t length,
4542
SeekMode seekMode = SeekMode::exact);
@@ -100,6 +97,10 @@ class VideoDecoder {
10097
// Returns the metadata for the container.
10198
ContainerMetadata getContainerMetadata() const;
10299

100+
// Returns the key frame indices as a tensor. The tensor is 1D and contains
101+
// int64 values, where each value is the frame index for a key frame.
102+
torch::Tensor getKeyFrameIndices(int streamIndex);
103+
103104
// --------------------------------------------------------------------------
104105
// ADDING STREAMS API
105106
// --------------------------------------------------------------------------
@@ -287,12 +288,19 @@ class VideoDecoder {
287288

288289
struct FrameInfo {
289290
int64_t pts = 0;
290-
// The value of this default is important: the last frame's nextPts will be
291-
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
292-
// structs with *increasing* nextPts values. That's a necessary condition
293-
// for the binary searches on those values to work properly (as typically
294-
// done during pts -> index conversions.)
291+
292+
// The value of the nextPts default is important: the last frame's nextPts
293+
// will be INT64_MAX, which ensures that the allFrames vec contains
294+
// FrameInfo structs with *increasing* nextPts values. That's a necessary
295+
// condition for the binary searches on those values to work properly (as
296+
// typically done during pts -> index conversions).
295297
int64_t nextPts = INT64_MAX;
298+
299+
// Note that frameIndex is ALWAYS the index into all of the frames in that
300+
// stream, even when the FrameInfo is part of the key frame index. Given a
301+
// FrameInfo for a key frame, the frameIndex allows us to know which frame
302+
// that is in the stream.
303+
int64_t frameIndex = 0;
296304
};
297305

298306
struct FilterGraphContext {
@@ -364,8 +372,7 @@ class VideoDecoder {
364372

365373
void maybeSeekToBeforeDesiredPts();
366374

367-
AVFrameStream decodeAVFrame(
368-
std::function<bool(int, AVFrame*)> filterFunction);
375+
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
369376

370377
FrameOutput getNextFrameNoDemuxInternal(
371378
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -469,9 +476,8 @@ class VideoDecoder {
469476
ContainerMetadata containerMetadata_;
470477
UniqueAVFormatContext formatContext_;
471478
std::map<int, StreamInfo> streamInfos_;
472-
// Stores the stream indices of the active streams, i.e. the streams we are
473-
// decoding and returning to the user.
474-
std::set<int> activeStreamIndices_;
479+
const int NO_ACTIVE_STREAM = -2;
480+
int activeStreamIndex_ = NO_ACTIVE_STREAM;
475481
// Set when the user wants to seek and stores the desired pts that the user
476482
// wants to seek to.
477483
std::optional<double> desiredPtsSeconds_;

0 commit comments

Comments
 (0)