Skip to content

Commit e7b176d

Browse files
authored
Simplify more APIs (#512)
1 parent 6f01650 commit e7b176d

File tree

2 files changed

+27
-37
lines changed

2 files changed

+27
-37
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal(
562562
containerMetadata_.allStreamMetadata[activeStreamIndex_];
563563
validateFrameIndex(streamMetadata, frameIndex);
564564

565-
int64_t pts = getPts(streamInfo, streamMetadata, frameIndex);
565+
int64_t pts = getPts(frameIndex);
566566
setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase));
567567
return getNextFrameInternal(preAllocatedOutputTensor);
568568
}
@@ -703,7 +703,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
703703

704704
const auto& streamMetadata =
705705
containerMetadata_.allStreamMetadata[activeStreamIndex_];
706-
const auto& streamInfo = streamInfos_[activeStreamIndex_];
707706

708707
double minSeconds = getMinSeconds(streamMetadata);
709708
double maxSeconds = getMaxSeconds(streamMetadata);
@@ -722,8 +721,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt(
722721
"; must be in range [" + std::to_string(minSeconds) + ", " +
723722
std::to_string(maxSeconds) + ").");
724723

725-
frameIndices[i] =
726-
secondsToIndexLowerBound(frameSeconds, streamInfo, streamMetadata);
724+
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
727725
}
728726

729727
return getFramesAtIndices(frameIndices);
@@ -794,10 +792,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
794792
// particular frame, we need to figure out if it is ordered after the
795793
// frame's pts, but before the next frames's pts.
796794

797-
int64_t startFrameIndex =
798-
secondsToIndexLowerBound(startSeconds, streamInfo, streamMetadata);
799-
int64_t stopFrameIndex =
800-
secondsToIndexUpperBound(stopSeconds, streamInfo, streamMetadata);
795+
int64_t startFrameIndex = secondsToIndexLowerBound(startSeconds);
796+
int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds);
801797
int64_t numFrames = stopFrameIndex - startFrameIndex;
802798

803799
FrameBatchOutput frameBatchOutput(
@@ -1499,10 +1495,8 @@ int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex(
14991495
return upperBound - 1 - keyFrames.begin();
15001496
}
15011497

1502-
int64_t VideoDecoder::secondsToIndexLowerBound(
1503-
double seconds,
1504-
const StreamInfo& streamInfo,
1505-
const StreamMetadata& streamMetadata) {
1498+
int64_t VideoDecoder::secondsToIndexLowerBound(double seconds) {
1499+
auto& streamInfo = streamInfos_[activeStreamIndex_];
15061500
switch (seekMode_) {
15071501
case SeekMode::exact: {
15081502
auto frame = std::lower_bound(
@@ -1515,17 +1509,18 @@ int64_t VideoDecoder::secondsToIndexLowerBound(
15151509

15161510
return frame - streamInfo.allFrames.begin();
15171511
}
1518-
case SeekMode::approximate:
1512+
case SeekMode::approximate: {
1513+
auto& streamMetadata =
1514+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
15191515
return std::floor(seconds * streamMetadata.averageFps.value());
1516+
}
15201517
default:
15211518
throw std::runtime_error("Unknown SeekMode");
15221519
}
15231520
}
15241521

1525-
int64_t VideoDecoder::secondsToIndexUpperBound(
1526-
double seconds,
1527-
const StreamInfo& streamInfo,
1528-
const StreamMetadata& streamMetadata) {
1522+
int64_t VideoDecoder::secondsToIndexUpperBound(double seconds) {
1523+
auto& streamInfo = streamInfos_[activeStreamIndex_];
15291524
switch (seekMode_) {
15301525
case SeekMode::exact: {
15311526
auto frame = std::upper_bound(
@@ -1538,23 +1533,27 @@ int64_t VideoDecoder::secondsToIndexUpperBound(
15381533

15391534
return frame - streamInfo.allFrames.begin();
15401535
}
1541-
case SeekMode::approximate:
1536+
case SeekMode::approximate: {
1537+
auto& streamMetadata =
1538+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
15421539
return std::ceil(seconds * streamMetadata.averageFps.value());
1540+
}
15431541
default:
15441542
throw std::runtime_error("Unknown SeekMode");
15451543
}
15461544
}
15471545

1548-
int64_t VideoDecoder::getPts(
1549-
const StreamInfo& streamInfo,
1550-
const StreamMetadata& streamMetadata,
1551-
int64_t frameIndex) {
1546+
int64_t VideoDecoder::getPts(int64_t frameIndex) {
1547+
auto& streamInfo = streamInfos_[activeStreamIndex_];
15521548
switch (seekMode_) {
15531549
case SeekMode::exact:
15541550
return streamInfo.allFrames[frameIndex].pts;
1555-
case SeekMode::approximate:
1551+
case SeekMode::approximate: {
1552+
auto& streamMetadata =
1553+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
15561554
return secondsToClosestPts(
15571555
frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase);
1556+
}
15581557
default:
15591558
throw std::runtime_error("Unknown SeekMode");
15601559
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -414,20 +414,11 @@ class VideoDecoder {
414414
const std::vector<VideoDecoder::FrameInfo>& keyFrames,
415415
int64_t pts) const;
416416

417-
int64_t secondsToIndexLowerBound(
418-
double seconds,
419-
const StreamInfo& streamInfo,
420-
const StreamMetadata& streamMetadata);
421-
422-
int64_t secondsToIndexUpperBound(
423-
double seconds,
424-
const StreamInfo& streamInfo,
425-
const StreamMetadata& streamMetadata);
426-
427-
int64_t getPts(
428-
const StreamInfo& streamInfo,
429-
const StreamMetadata& streamMetadata,
430-
int64_t frameIndex);
417+
int64_t secondsToIndexLowerBound(double seconds);
418+
419+
int64_t secondsToIndexUpperBound(double seconds);
420+
421+
int64_t getPts(int64_t frameIndex);
431422

432423
// --------------------------------------------------------------------------
433424
// STREAM AND METADATA APIS

0 commit comments

Comments
 (0)