Skip to content

Commit 2ea149d

Browse files
committed
Refactor seeking to only store pts as int64 timestamp
1 parent 82924d2 commit 2ea149d

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 25 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -647,26 +647,25 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
647647
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
648648
// the comment of canWeAvoidSeeking() for details.
649649
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
650+
TORCH_CHECK(
651+
hasDesiredPts_,
652+
"maybeSeekToBeforeDesiredPts() called when hasDesiredPts_ is false");
650653
if (activeStreamIndices_.size() == 0) {
651654
return;
652655
}
653-
for (int streamIndex : activeStreamIndices_) {
654-
StreamInfo& streamInfo = streams_[streamIndex];
655-
streamInfo.discardFramesBeforePts =
656-
*maybeDesiredPts_ * streamInfo.timeBase.den;
657-
}
658656

659657
decodeStats_.numSeeksAttempted++;
660658
// See comment for canWeAvoidSeeking() for details on why this optimization
661659
// works.
662660
bool mustSeek = false;
663661
for (int streamIndex : activeStreamIndices_) {
664662
StreamInfo& streamInfo = streams_[streamIndex];
665-
int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den;
666663
if (!canWeAvoidSeekingForStream(
667-
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
664+
streamInfo,
665+
streamInfo.currentPts,
666+
*streamInfo.discardFramesBeforePts)) {
668667
VLOG(5) << "Seeking is needed for streamIndex=" << streamIndex
669-
<< " desiredPts=" << desiredPtsForStream
668+
<< " desiredPts=" << *streamInfo.discardFramesBeforePts
670669
<< " currentPts=" << streamInfo.currentPts;
671670
mustSeek = true;
672671
break;
@@ -678,7 +677,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
678677
}
679678
int firstActiveStreamIndex = *activeStreamIndices_.begin();
680679
const auto& firstStreamInfo = streams_[firstActiveStreamIndex];
681-
int64_t desiredPts = *maybeDesiredPts_ * firstStreamInfo.timeBase.den;
680+
int64_t desiredPts = *firstStreamInfo.discardFramesBeforePts;
682681

683682
// For some encodings like H265, FFMPEG sometimes seeks past the point we
684683
// set as the max_ts. So we use our own index to give it the exact pts of
@@ -718,10 +717,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
718717
}
719718
VLOG(9) << "Starting getDecodedOutputWithFilter()";
720719
resetDecodeStats();
721-
if (maybeDesiredPts_.has_value()) {
722-
VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_;
720+
if (hasDesiredPts_) {
723721
maybeSeekToBeforeDesiredPts();
724-
maybeDesiredPts_ = std::nullopt;
722+
hasDesiredPts_ = false;
723+
// FIXME: should we also reset each stream info's discardFramesBeforePts?
725724
VLOG(9) << "seeking done";
726725
}
727726
auto seekDone = std::chrono::high_resolution_clock::now();
@@ -988,7 +987,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
988987
validateFrameIndex(stream, frameIndex);
989988

990989
int64_t pts = stream.allFrames[frameIndex].pts;
991-
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
990+
setCursorPts(pts);
992991
return getNextDecodedOutputNoDemux();
993992
}
994993

@@ -1010,7 +1009,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10101009
"Invalid frame index=" + std::to_string(frameIndex));
10111010
}
10121011
int64_t pts = stream.allFrames[frameIndex].pts;
1013-
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
1012+
setCursorPts(pts);
10141013
auto rawSingleOutput = getNextRawDecodedOutputNoDemux();
10151014
if (stream.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
10161015
// We are using sws_scale to convert the frame to tensor. sws_scale can
@@ -1179,7 +1178,18 @@ VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
11791178
}
11801179

11811180
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
1182-
maybeDesiredPts_ = seconds;
1181+
for (int streamIndex : activeStreamIndices_) {
1182+
StreamInfo& streamInfo = streams_[streamIndex];
1183+
streamInfo.discardFramesBeforePts = seconds * streamInfo.timeBase.den;
1184+
}
1185+
hasDesiredPts_ = true;
1186+
}
1187+
1188+
void VideoDecoder::setCursorPts(int64_t pts) {
1189+
for (int streamIndex : activeStreamIndices_) {
1190+
streams_[streamIndex].discardFramesBeforePts = pts;
1191+
}
1192+
hasDesiredPts_ = true;
11831193
}
11841194

11851195
VideoDecoder::DecodeStats VideoDecoder::getDecodeStats() const {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -367,6 +367,7 @@ class VideoDecoder {
367367
void convertAVFrameToDecodedOutputOnCPU(
368368
RawDecodedOutput& rawOutput,
369369
DecodedOutput& output);
370+
void setCursorPts(int64_t pts);
370371

371372
DecoderOptions options_;
372373
ContainerMetadata containerMetadata_;
@@ -375,9 +376,9 @@ class VideoDecoder {
375376
// Stores the stream indices of the active streams, i.e. the streams we are
376377
// decoding and returning to the user.
377378
std::set<int> activeStreamIndices_;
378-
// Set when the user wants to seek and stores the desired pts that the user
379-
// wants to seek to.
380-
std::optional<double> maybeDesiredPts_;
379+
// True when the user wants to seek. The actual pts values to seek to are
380+
// stored in the per-stream metadata in discardFramesBeforePts.
381+
// Defaults to false (no seek pending): getDecodedOutputWithFilter() branches
// on this flag, so it must never be read while uninitialized.
bool hasDesiredPts_ = false;
381382

382383
// Stores various internal decoding stats.
383384
DecodeStats decodeStats_;

0 commit comments

Comments
 (0)