Skip to content

Commit f7db2c9

Browse files
author
pytorchbot
committed
2025-01-30 nightly release (28b0de0)
1 parent 7cc40f1 commit f7db2c9

File tree

8 files changed

+163
-101
lines changed

8 files changed

+163
-101
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 75 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -435,11 +435,9 @@ int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
435435
void VideoDecoder::addVideoStreamDecoder(
436436
int preferredStreamIndex,
437437
const VideoStreamOptions& videoStreamOptions) {
438-
if (activeStreamIndices_.count(preferredStreamIndex) > 0) {
439-
throw std::invalid_argument(
440-
"Stream with index " + std::to_string(preferredStreamIndex) +
441-
" is already active.");
442-
}
438+
TORCH_CHECK(
439+
activeStreamIndex_ == NO_ACTIVE_STREAM,
440+
"Can only add one single stream.");
443441
TORCH_CHECK(formatContext_.get() != nullptr);
444442

445443
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
@@ -506,7 +504,7 @@ void VideoDecoder::addVideoStreamDecoder(
506504
}
507505

508506
codecContext->time_base = streamInfo.stream->time_base;
509-
activeStreamIndices_.insert(streamIndex);
507+
activeStreamIndex_ = streamIndex;
510508
updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext);
511509
streamInfo.videoStreamOptions = videoStreamOptions;
512510

@@ -538,6 +536,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
538536
return containerMetadata_;
539537
}
540538

539+
torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
540+
validateUserProvidedStreamIndex(streamIndex);
541+
validateScannedAllStreams("getKeyFrameIndices");
542+
543+
const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames;
544+
torch::Tensor keyFrameIndices =
545+
torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
546+
for (size_t i = 0; i < keyFrames.size(); ++i) {
547+
keyFrameIndices[i] = keyFrames[i].frameIndex;
548+
}
549+
550+
return keyFrameIndices;
551+
}
552+
541553
int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex(
542554
AVStream* stream,
543555
int64_t pts) const {
@@ -654,7 +666,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
654666
return frameInfo1.pts < frameInfo2.pts;
655667
});
656668

669+
size_t keyIndex = 0;
657670
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
671+
streamInfo.allFrames[i].frameIndex = i;
672+
673+
// For correctly encoded files, we shouldn't need to ensure that keyIndex
674+
// is less than the number of key frames. That is, the relationship
675+
// between the frames in allFrames and keyFrames should be such that
676+
// keyIndex is always a valid index into keyFrames. But we're being
677+
// defensive in case we encounter incorrectly encoded files.
678+
if (keyIndex < streamInfo.keyFrames.size() &&
679+
streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
680+
streamInfo.keyFrames[keyIndex].frameIndex = i;
681+
++keyIndex;
682+
}
683+
658684
if (i + 1 < streamInfo.allFrames.size()) {
659685
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
660686
}
@@ -726,53 +752,39 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
726752
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
727753
// the comment of canWeAvoidSeeking() for details.
728754
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
729-
if (activeStreamIndices_.size() == 0) {
755+
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
730756
return;
731757
}
732-
for (int streamIndex : activeStreamIndices_) {
733-
StreamInfo& streamInfo = streamInfos_[streamIndex];
734-
// clang-format off: clang format clashes
735-
streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
736-
// clang-format on
737-
}
758+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
759+
streamInfo.discardFramesBeforePts =
760+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
738761

739762
decodeStats_.numSeeksAttempted++;
740-
// See comment for canWeAvoidSeeking() for details on why this optimization
741-
// works.
742-
bool mustSeek = false;
743-
for (int streamIndex : activeStreamIndices_) {
744-
StreamInfo& streamInfo = streamInfos_[streamIndex];
745-
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
746-
if (!canWeAvoidSeekingForStream(
747-
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
748-
mustSeek = true;
749-
break;
750-
}
751-
}
752-
if (!mustSeek) {
763+
764+
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
765+
if (canWeAvoidSeekingForStream(
766+
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
753767
decodeStats_.numSeeksSkipped++;
754768
return;
755769
}
756-
int firstActiveStreamIndex = *activeStreamIndices_.begin();
757-
const auto& firstStreamInfo = streamInfos_[firstActiveStreamIndex];
758770
int64_t desiredPts =
759-
secondsToClosestPts(*desiredPtsSeconds_, firstStreamInfo.timeBase);
771+
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
760772

761773
// For some encodings like H265, FFMPEG sometimes seeks past the point we
762774
// set as the max_ts. So we use our own index to give it the exact pts of
763775
// the key frame that we want to seek to.
764776
// See https://github.com/pytorch/torchcodec/issues/179 for more details.
765777
// See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug.
766-
if (!firstStreamInfo.keyFrames.empty()) {
778+
if (!streamInfo.keyFrames.empty()) {
767779
int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex(
768-
firstStreamInfo.keyFrames, desiredPts);
780+
streamInfo.keyFrames, desiredPts);
769781
desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0);
770-
desiredPts = firstStreamInfo.keyFrames[desiredKeyFrameIndex].pts;
782+
desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts;
771783
}
772784

773785
int ffmepgStatus = avformat_seek_file(
774786
formatContext_.get(),
775-
firstStreamInfo.streamIndex,
787+
streamInfo.streamIndex,
776788
INT64_MIN,
777789
desiredPts,
778790
desiredPts,
@@ -783,15 +795,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
783795
getFFMPEGErrorStringFromErrorCode(ffmepgStatus));
784796
}
785797
decodeStats_.numFlushes++;
786-
for (int streamIndex : activeStreamIndices_) {
787-
StreamInfo& streamInfo = streamInfos_[streamIndex];
788-
avcodec_flush_buffers(streamInfo.codecContext.get());
789-
}
798+
avcodec_flush_buffers(streamInfo.codecContext.get());
790799
}
791800

792801
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
793-
std::function<bool(int, AVFrame*)> filterFunction) {
794-
if (activeStreamIndices_.size() == 0) {
802+
std::function<bool(AVFrame*)> filterFunction) {
803+
if (activeStreamIndex_ == NO_ACTIVE_STREAM) {
795804
throw std::runtime_error("No active streams configured.");
796805
}
797806

@@ -803,44 +812,25 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
803812
desiredPtsSeconds_ = std::nullopt;
804813
}
805814

815+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
816+
806817
// Need to get the next frame or error from PopFrame.
807818
UniqueAVFrame avFrame(av_frame_alloc());
808819
AutoAVPacket autoAVPacket;
809820
int ffmpegStatus = AVSUCCESS;
810821
bool reachedEOF = false;
811-
int frameStreamIndex = -1;
812822
while (true) {
813-
frameStreamIndex = -1;
814-
bool gotPermanentErrorOnAnyActiveStream = false;
815-
816-
// Get a frame on an active stream. Note that we don't know ahead of time
817-
// which streams have frames to receive, so we linearly try the active
818-
// streams.
819-
for (int streamIndex : activeStreamIndices_) {
820-
StreamInfo& streamInfo = streamInfos_[streamIndex];
821-
ffmpegStatus =
822-
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
823-
824-
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
825-
gotPermanentErrorOnAnyActiveStream = true;
826-
break;
827-
}
823+
ffmpegStatus =
824+
avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
828825

829-
if (ffmpegStatus == AVSUCCESS) {
830-
frameStreamIndex = streamIndex;
831-
break;
832-
}
833-
}
834-
835-
if (gotPermanentErrorOnAnyActiveStream) {
826+
if (ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN)) {
827+
// Non-retriable error
836828
break;
837829
}
838830

839831
decodeStats_.numFramesReceivedByDecoder++;
840-
841832
// Is this the kind of frame we're looking for?
842-
if (ffmpegStatus == AVSUCCESS &&
843-
filterFunction(frameStreamIndex, avFrame.get())) {
833+
if (ffmpegStatus == AVSUCCESS && filterFunction(avFrame.get())) {
844834
// Yes, this is the frame we'll return; break out of the decoding loop.
845835
break;
846836
} else if (ffmpegStatus == AVSUCCESS) {
@@ -865,18 +855,15 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
865855
decodeStats_.numPacketsRead++;
866856

867857
if (ffmpegStatus == AVERROR_EOF) {
868-
// End of file reached. We must drain all codecs by sending a nullptr
858+
// End of file reached. We must drain the codec by sending a nullptr
869859
// packet.
870-
for (int streamIndex : activeStreamIndices_) {
871-
StreamInfo& streamInfo = streamInfos_[streamIndex];
872-
ffmpegStatus = avcodec_send_packet(
873-
streamInfo.codecContext.get(),
874-
/*avpkt=*/nullptr);
875-
if (ffmpegStatus < AVSUCCESS) {
876-
throw std::runtime_error(
877-
"Could not flush decoder: " +
878-
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
879-
}
860+
ffmpegStatus = avcodec_send_packet(
861+
streamInfo.codecContext.get(),
862+
/*avpkt=*/nullptr);
863+
if (ffmpegStatus < AVSUCCESS) {
864+
throw std::runtime_error(
865+
"Could not flush decoder: " +
866+
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
880867
}
881868

882869
// We've reached the end of file so we can't read any more packets from
@@ -892,15 +879,14 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
892879
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
893880
}
894881

895-
if (activeStreamIndices_.count(packet->stream_index) == 0) {
896-
// This packet is not for any of the active streams.
882+
if (packet->stream_index != activeStreamIndex_) {
897883
continue;
898884
}
899885

900886
// We got a valid packet. Send it to the decoder, and we'll receive it in
901887
// the next iteration.
902-
ffmpegStatus = avcodec_send_packet(
903-
streamInfos_[packet->stream_index].codecContext.get(), packet.get());
888+
ffmpegStatus =
889+
avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
904890
if (ffmpegStatus < AVSUCCESS) {
905891
throw std::runtime_error(
906892
"Could not push packet to decoder: " +
@@ -927,11 +913,10 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
927913
// haven't received as frames. Eventually we will either hit AVERROR_EOF from
928914
// av_receive_frame() or the user will have seeked to a different location in
929915
// the file and that will flush the decoder.
930-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
931-
activeStreamInfo.currentPts = avFrame->pts;
932-
activeStreamInfo.currentDuration = getDuration(avFrame);
916+
streamInfo.currentPts = avFrame->pts;
917+
streamInfo.currentDuration = getDuration(avFrame);
933918

934-
return AVFrameStream(std::move(avFrame), frameStreamIndex);
919+
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
935920
}
936921

937922
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
@@ -1096,8 +1081,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux(
10961081

10971082
setCursorPtsInSeconds(seconds);
10981083
AVFrameStream avFrameStream =
1099-
decodeAVFrame([seconds, this](int frameStreamIndex, AVFrame* avFrame) {
1100-
StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
1084+
decodeAVFrame([seconds, this](AVFrame* avFrame) {
1085+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
11011086
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
11021087
double frameEndTime = ptsToSeconds(
11031088
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
@@ -1496,11 +1481,10 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() {
14961481

14971482
VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal(
14981483
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1499-
AVFrameStream avFrameStream =
1500-
decodeAVFrame([this](int frameStreamIndex, AVFrame* avFrame) {
1501-
StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
1502-
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1503-
});
1484+
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
1485+
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1486+
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
1487+
});
15041488
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
15051489
}
15061490

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ class VideoDecoder {
9797
// Returns the metadata for the container.
9898
ContainerMetadata getContainerMetadata() const;
9999

100+
// Returns the key frame indices as a tensor. The tensor is 1D and contains
101+
// int64 values, where each value is the frame index for a key frame.
102+
torch::Tensor getKeyFrameIndices(int streamIndex);
103+
100104
// --------------------------------------------------------------------------
101105
// ADDING STREAMS API
102106
// --------------------------------------------------------------------------
@@ -284,12 +288,19 @@ class VideoDecoder {
284288

285289
struct FrameInfo {
286290
int64_t pts = 0;
287-
// The value of this default is important: the last frame's nextPts will be
288-
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
289-
// structs with *increasing* nextPts values. That's a necessary condition
290-
// for the binary searches on those values to work properly (as typically
291-
// done during pts -> index conversions.)
291+
292+
// The value of the nextPts default is important: the last frame's nextPts
293+
// will be INT64_MAX, which ensures that the allFrames vec contains
294+
// FrameInfo structs with *increasing* nextPts values. That's a necessary
295+
// condition for the binary searches on those values to work properly (as
296+
// typically done during pts -> index conversions).
292297
int64_t nextPts = INT64_MAX;
298+
299+
// Note that frameIndex is ALWAYS the index into all of the frames in that
300+
// stream, even when the FrameInfo is part of the key frame index. Given a
301+
// FrameInfo for a key frame, the frameIndex allows us to know which frame
302+
// that is in the stream.
303+
int64_t frameIndex = 0;
293304
};
294305

295306
struct FilterGraphContext {
@@ -361,8 +372,7 @@ class VideoDecoder {
361372

362373
void maybeSeekToBeforeDesiredPts();
363374

364-
AVFrameStream decodeAVFrame(
365-
std::function<bool(int, AVFrame*)> filterFunction);
375+
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
366376

367377
FrameOutput getNextFrameNoDemuxInternal(
368378
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
@@ -469,9 +479,8 @@ class VideoDecoder {
469479
ContainerMetadata containerMetadata_;
470480
UniqueAVFormatContext formatContext_;
471481
std::map<int, StreamInfo> streamInfos_;
472-
// Stores the stream indices of the active streams, i.e. the streams we are
473-
// decoding and returning to the user.
474-
std::set<int> activeStreamIndices_;
482+
const int NO_ACTIVE_STREAM = -2;
483+
int activeStreamIndex_ = NO_ACTIVE_STREAM;
475484
// Set when the user wants to seek and stores the desired pts that the user
476485
// wants to seek to.
477486
std::optional<double> desiredPtsSeconds_;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4848
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
4949
m.def(
5050
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51+
m.def(
52+
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
5153
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5254
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5355
m.def(
@@ -334,6 +336,13 @@ bool _test_frame_pts_equality(
334336
videoDecoder->getPtsSecondsForFrame(stream_index, frame_index);
335337
}
336338

339+
torch::Tensor _get_key_frame_indices(
340+
at::Tensor& decoder,
341+
int64_t stream_index) {
342+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
343+
return videoDecoder->getKeyFrameIndices(stream_index);
344+
}
345+
337346
std::string get_json_metadata(at::Tensor& decoder) {
338347
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
339348

@@ -526,6 +535,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
526535
m.impl("add_video_stream", &add_video_stream);
527536
m.impl("_add_video_stream", &_add_video_stream);
528537
m.impl("get_next_frame", &get_next_frame);
538+
m.impl("_get_key_frame_indices", &_get_key_frame_indices);
529539
m.impl("get_json_metadata", &get_json_metadata);
530540
m.impl("get_container_json_metadata", &get_container_json_metadata);
531541
m.impl("get_stream_json_metadata", &get_stream_json_metadata);

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ bool _test_frame_pts_equality(
137137
int64_t frame_index,
138138
double pts_seconds_to_test);
139139

140+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
141+
140142
// Get the metadata from the video as a string.
141143
std::string get_json_metadata(at::Tensor& decoder);
142144

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from .video_decoder_ops import (
1515
_add_video_stream,
16+
_get_key_frame_indices,
1617
_test_frame_pts_equality,
1718
add_video_stream,
1819
create_from_bytes,

0 commit comments

Comments
 (0)