Skip to content

Commit db4a58b

Browse files
author
pytorchbot
committed
2025-02-15 nightly release (8317b34)
1 parent d013d6a commit db4a58b

File tree

2 files changed

+34
-46
lines changed

2 files changed

+34
-46
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 29 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -562,7 +562,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
562562
containerMetadata_.allStreamMetadata[activeStreamIndex_];
563563
validateFrameIndex(streamMetadata, frameIndex);
564564

565-
int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
565+
int64_t pts = getPts(frameIndex);
566566
setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
567567
return getNextFrameInternal(preAllocatedOutputTensor);
568568
}
@@ -703,7 +703,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
703703

704704
const auto& streamMetadata =
705705
containerMetadata_.allStreamMetadata[activeStreamIndex_];
706-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
707706

708707
double minSeconds = getMinSeconds(streamMetadata);
709708
double maxSeconds = getMaxSeconds(streamMetadata);
@@ -722,8 +721,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
722721
"; must be in range [" + std::to_string(minSeconds) + ", " +
723722
std::to_string(maxSeconds) + ").");
724723

725-
frameIndices[i] =
726-
secondsToIndexLowerBound(frameSeconds, streamInfo, streamMetadata);
724+
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
727725
}
728726

729727
return getFramesAtIndices(frameIndices);
@@ -794,10 +792,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
794792
// particular frame, we need to figure out if it is ordered after the
795793
// frame's pts, but before the next frame's pts.
796794

797-
int64_t startFrameIndex =
798-
secondsToIndexLowerBound(startSeconds, streamInfo, streamMetadata);
799-
int64_t stopFrameIndex =
800-
secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata);
795+
int64_t startFrameIndex = secondsToIndexLowerBound(startSeconds);
796+
int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
801797
int64_t numFrames = stopFrameIndex - startFrameIndex;
802798

803799
FrameBatchOutput frameBatchOutput(
@@ -863,10 +859,10 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
863859
// We are seeking forwards.
864860
// We can only skip a seek if both lastDecodedAvFramePts and targetPts share
865861
// the same keyframe.
866-
int currentKeyFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
862+
int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
867863
int targetKeyFrameIndex = getKeyFrameIndexForPts(targetPts);
868-
return currentKeyFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
869-
currentKeyFrameIndex == targetKeyFrameIndex;
864+
return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
865+
lastDecodedAvFrameIndex == targetKeyFrameIndex;
870866
}
871867

872868
// This method looks at currentPts and desiredPts and seeks in the
@@ -875,18 +871,16 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
875871
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
876872
validateActiveStream();
877873
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
878-
streamInfo.discardFramesBeforePts =
874+
875+
int64_t desiredPts =
879876
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
877+
streamInfo.discardFramesBeforePts = desiredPts;
880878

881879
decodeStats_.numSeeksAttempted++;
882-
883-
int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den;
884-
if (canWeAvoidSeeking(desiredPtsForStream)) {
880+
if (canWeAvoidSeeking(desiredPts)) {
885881
decodeStats_.numSeeksSkipped++;
886882
return;
887883
}
888-
int64_t desiredPts =
889-
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
890884

891885
// For some encodings like H265, FFMPEG sometimes seeks past the point we
892886
// set as the max_ts. So we use our own index to give it the exact pts of
@@ -1499,10 +1493,8 @@ int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
14991493
return upperBound - 1 - keyFrames.begin();
15001494
}
15011495

1502-
int64_t VideoDecoder::secondsToIndexLowerBound(
1503-
double seconds,
1504-
const StreamInfo& streamInfo,
1505-
const StreamMetadata& streamMetadata) {
1496+
int64_t VideoDecoder::secondsToIndexLowerBound(double seconds) {
1497+
auto& streamInfo = streamInfos_[activeStreamIndex_];
15061498
switch (seekMode_) {
15071499
case SeekMode::exact: {
15081500
auto frame = std::lower_bound(
@@ -1515,17 +1507,18 @@ int64_t VideoDecoder::secondsToIndexLowerBound(
15151507

15161508
return frame - streamInfo.allFrames.begin();
15171509
}
1518-
case SeekMode::approximate:
1510+
case SeekMode::approximate: {
1511+
auto& streamMetadata =
1512+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
15191513
return std::floor(seconds * streamMetadata.averageFps.value());
1514+
}
15201515
default:
15211516
throw std::runtime_error("Unknown SeekMode");
15221517
}
15231518
}
15241519

1525-
int64_t VideoDecoder::secondsToIndexUpperBound(
1526-
double seconds,
1527-
const StreamInfo& streamInfo,
1528-
const StreamMetadata& streamMetadata) {
1520+
int64_t VideoDecoder::secondsToIndexUpperBound(double seconds) {
1521+
auto& streamInfo = streamInfos_[activeStreamIndex_];
15291522
switch (seekMode_) {
15301523
case SeekMode::exact: {
15311524
auto frame = std::upper_bound(
@@ -1538,23 +1531,27 @@ int64_t VideoDecoder::secondsToIndexUpperBound(
15381531

15391532
return frame - streamInfo.allFrames.begin();
15401533
}
1541-
case SeekMode::approximate:
1534+
case SeekMode::approximate: {
1535+
auto& streamMetadata =
1536+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
15421537
return std::ceil(seconds * streamMetadata.averageFps.value());
1538+
}
15431539
default:
15441540
throw std::runtime_error("Unknown SeekMode");
15451541
}
15461542
}
15471543

1548-
int64_t VideoDecoder::getPts(
1549-
const StreamInfo& streamInfo,
1550-
const StreamMetadata& streamMetadata,
1551-
int64_t frameIndex) {
1544+
int64_t VideoDecoder::getPts(int64_t frameIndex) {
1545+
auto& streamInfo = streamInfos_[activeStreamIndex_];
15521546
switch (seekMode_) {
15531547
case SeekMode::exact:
15541548
return streamInfo.allFrames[frameIndex].pts;
1555-
case SeekMode::approximate:
1549+
case SeekMode::approximate: {
1550+
auto& streamMetadata =
1551+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
15561552
return secondsToClosestPts(
15571553
frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase);
1554+
}
15581555
default:
15591556
throw std::runtime_error("Unknown SeekMode");
15601557
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -414,20 +414,11 @@ class VideoDecoder {
414414
const std::vector<VideoDecoder::FrameInfo>& keyFrames,
415415
int64_t pts) const;
416416

417-
int64_t secondsToIndexLowerBound(
418-
double seconds,
419-
const StreamInfo& streamInfo,
420-
const StreamMetadata& streamMetadata);
421-
422-
int64_t secondsToIndexUpperBound(
423-
double seconds,
424-
const StreamInfo& streamInfo,
425-
const StreamMetadata& streamMetadata);
426-
427-
int64_t getPts(
428-
const StreamInfo& streamInfo,
429-
const StreamMetadata& streamMetadata,
430-
int64_t frameIndex);
417+
int64_t secondsToIndexLowerBound(double seconds);
418+
419+
int64_t secondsToIndexUpperBound(double seconds);
420+
421+
int64_t getPts(int64_t frameIndex);
431422

432423
// --------------------------------------------------------------------------
433424
// STREAM AND METADATA APIS

0 commit comments

Comments
 (0)