Skip to content

Commit 289ff5d

Browse files
authored
refactor metadata fallback logic (#1021)
Co-authored-by: Molly Xu <[email protected]> Closes #1009
1 parent 45647a1 commit 289ff5d

File tree

13 files changed

+454
-330
lines changed

13 files changed

+454
-330
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ function(make_torchcodec_libraries
9696
Encoder.cpp
9797
ValidationUtils.cpp
9898
Transform.cpp
99+
Metadata.cpp
99100
)
100101

101102
if(ENABLE_CUDA)

src/torchcodec/_core/Metadata.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "Metadata.h"
8+
#include "torch/types.h"
9+
10+
namespace facebook::torchcodec {
11+
12+
// Returns the duration of the stream, in seconds, for the given seek mode.
//
// - exact / custom_frame_mappings: the difference between
//   endStreamPtsSecondsFromContent and beginStreamPtsSecondsFromContent. A
//   TORCH_CHECK fails if either value is missing.
// - approximate: durationSecondsFromHeader when it is set; otherwise
//   numFramesFromHeader / averageFpsFromHeader when both are set and the fps
//   is non-zero; otherwise std::nullopt.
//
// Any other SeekMode value fails a TORCH_CHECK.
std::optional<double> StreamMetadata::getDurationSeconds(
    SeekMode seekMode) const {
  const bool usesContentMetadata = seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact;
  if (usesContentMetadata) {
    TORCH_CHECK(
        endStreamPtsSecondsFromContent.has_value() &&
            beginStreamPtsSecondsFromContent.has_value(),
        "Missing beginStreamPtsSecondsFromContent or endStreamPtsSecondsFromContent");
    return *endStreamPtsSecondsFromContent - *beginStreamPtsSecondsFromContent;
  }

  if (seekMode != SeekMode::approximate) {
    TORCH_CHECK(false, "Unknown SeekMode");
  }

  // Approximate mode: prefer the duration reported by the header.
  if (durationSecondsFromHeader.has_value()) {
    return *durationSecondsFromHeader;
  }

  // Otherwise derive it from the header frame count and frame rate, guarding
  // against a division by zero.
  const bool canDeriveFromFps = numFramesFromHeader.has_value() &&
      averageFpsFromHeader.has_value() && averageFpsFromHeader.value() != 0.0;
  if (!canDeriveFromFps) {
    return std::nullopt;
  }
  return static_cast<double>(*numFramesFromHeader) / *averageFpsFromHeader;
}
37+
38+
// Returns the start time of the stream, in seconds, for the given seek mode.
//
// - exact / custom_frame_mappings: beginStreamPtsSecondsFromContent. A
//   TORCH_CHECK fails if it is missing.
// - approximate: beginStreamPtsSecondsFromContent when it is set, else 0.0.
//
// Any other SeekMode value fails a TORCH_CHECK.
double StreamMetadata::getBeginStreamSeconds(SeekMode seekMode) const {
  const bool usesContentMetadata = seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact;
  if (usesContentMetadata) {
    TORCH_CHECK(
        beginStreamPtsSecondsFromContent.has_value(),
        "Missing beginStreamPtsSecondsFromContent");
    return *beginStreamPtsSecondsFromContent;
  }

  if (seekMode != SeekMode::approximate) {
    TORCH_CHECK(false, "Unknown SeekMode");
  }

  // Approximate mode: fall back to 0.0 when no content value is available.
  return beginStreamPtsSecondsFromContent.value_or(0.0);
}
55+
56+
// Returns the end time of the stream, in seconds, for the given seek mode.
//
// - exact / custom_frame_mappings: endStreamPtsSecondsFromContent. A
//   TORCH_CHECK fails if it is missing.
// - approximate: endStreamPtsSecondsFromContent when it is set; otherwise
//   whatever getDurationSeconds(seekMode) returns (possibly std::nullopt).
//
// Any other SeekMode value fails a TORCH_CHECK.
std::optional<double> StreamMetadata::getEndStreamSeconds(
    SeekMode seekMode) const {
  const bool usesContentMetadata = seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact;
  if (usesContentMetadata) {
    TORCH_CHECK(
        endStreamPtsSecondsFromContent.has_value(),
        "Missing endStreamPtsSecondsFromContent");
    return *endStreamPtsSecondsFromContent;
  }

  if (seekMode != SeekMode::approximate) {
    TORCH_CHECK(false, "Unknown SeekMode");
  }

  // Approximate mode: use the content value if present, else the duration.
  if (!endStreamPtsSecondsFromContent.has_value()) {
    return getDurationSeconds(seekMode);
  }
  return *endStreamPtsSecondsFromContent;
}
74+
75+
// Returns the number of frames in the stream for the given seek mode.
//
// - exact / custom_frame_mappings: numFramesFromContent. A TORCH_CHECK fails
//   if it is missing.
// - approximate: numFramesFromHeader when it is set; otherwise
//   averageFpsFromHeader * durationSecondsFromHeader converted to int64_t
//   with static_cast when both are set; otherwise std::nullopt.
//
// Any other SeekMode value fails a TORCH_CHECK.
std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
  const bool usesContentMetadata = seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact;
  if (usesContentMetadata) {
    TORCH_CHECK(
        numFramesFromContent.has_value(), "Missing numFramesFromContent");
    return *numFramesFromContent;
  }

  if (seekMode != SeekMode::approximate) {
    TORCH_CHECK(false, "Unknown SeekMode");
  }

  // Approximate mode: prefer the frame count reported by the header.
  if (numFramesFromHeader.has_value()) {
    return *numFramesFromHeader;
  }

  // Otherwise estimate it from the header frame rate and duration.
  const bool canEstimate =
      averageFpsFromHeader.has_value() && durationSecondsFromHeader.has_value();
  if (!canEstimate) {
    return std::nullopt;
  }
  const auto estimatedFrames =
      averageFpsFromHeader.value() * durationSecondsFromHeader.value();
  return static_cast<int64_t>(estimatedFrames);
}
97+
98+
// Returns the average frames-per-second of the stream for the given seek mode.
//
// - exact / custom_frame_mappings: getNumFrames(seekMode) divided by
//   (endStreamPtsSecondsFromContent - beginStreamPtsSecondsFromContent) when
//   the frame count and both bounds are available and that difference is
//   non-zero; otherwise averageFpsFromHeader.
// - approximate: averageFpsFromHeader.
//
// Any other SeekMode value fails a TORCH_CHECK.
std::optional<double> StreamMetadata::getAverageFps(SeekMode seekMode) const {
  if (seekMode == SeekMode::approximate) {
    return averageFpsFromHeader;
  }

  const bool usesContentMetadata = seekMode == SeekMode::custom_frame_mappings ||
      seekMode == SeekMode::exact;
  if (!usesContentMetadata) {
    TORCH_CHECK(false, "Unknown SeekMode");
  }

  const auto frameCount = getNumFrames(seekMode);
  const bool canDeriveFromContent = frameCount.has_value() &&
      beginStreamPtsSecondsFromContent.has_value() &&
      endStreamPtsSecondsFromContent.has_value();
  if (canDeriveFromContent) {
    const double spanSeconds =
        *endStreamPtsSecondsFromContent - *beginStreamPtsSecondsFromContent;
    // A zero-length span would divide by zero, so fall through to the header.
    if (spanSeconds != 0.0) {
      return static_cast<double>(*frameCount) / spanSeconds;
    }
  }
  return averageFpsFromHeader;
}
120+
121+
} // namespace facebook::torchcodec

src/torchcodec/_core/Metadata.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ extern "C" {
1818

1919
namespace facebook::torchcodec {
2020

21+
enum class SeekMode { exact, approximate, custom_frame_mappings };
22+
2123
struct StreamMetadata {
2224
// Common (video and audio) fields derived from the AVStream.
2325
int streamIndex;
@@ -52,6 +54,13 @@ struct StreamMetadata {
5254
std::optional<int64_t> sampleRate;
5355
std::optional<int64_t> numChannels;
5456
std::optional<std::string> sampleFormat;
57+
58+
// Computed methods with fallback logic
59+
std::optional<double> getDurationSeconds(SeekMode seekMode) const;
60+
double getBeginStreamSeconds(SeekMode seekMode) const;
61+
std::optional<double> getEndStreamSeconds(SeekMode seekMode) const;
62+
std::optional<int64_t> getNumFrames(SeekMode seekMode) const;
63+
std::optional<double> getAverageFps(SeekMode seekMode) const;
5564
};
5665

5766
struct ContainerMetadata {

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 17 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,14 @@ ContainerMetadata SingleStreamDecoder::getContainerMetadata() const {
367367
return containerMetadata_;
368368
}
369369

370+
// Accessor for the decoder's seek mode: returns the value stored in seekMode_.
SeekMode SingleStreamDecoder::getSeekMode() const {
  return seekMode_;
}
373+
374+
int SingleStreamDecoder::getActiveStreamIndex() const {
375+
return activeStreamIndex_;
376+
}
377+
370378
torch::Tensor SingleStreamDecoder::getKeyFrameIndices() {
371379
validateActiveStream(AVMEDIA_TYPE_VIDEO);
372380
validateScannedAllStreams("getKeyFrameIndices");
@@ -611,7 +619,7 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
611619
const auto& streamMetadata =
612620
containerMetadata_.allStreamMetadata[activeStreamIndex_];
613621

614-
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
622+
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
615623
if (numFrames.has_value()) {
616624
// If the frameIndex is negative, we convert it to a positive index
617625
frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value();
@@ -705,7 +713,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
705713

706714
// Note that if we do not have the number of frames available in our
707715
// metadata, then we assume that the upper part of the range is valid.
708-
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
716+
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
709717
if (numFrames.has_value()) {
710718
TORCH_CHECK(
711719
stop <= numFrames.value(),
@@ -779,8 +787,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
779787
const auto& streamMetadata =
780788
containerMetadata_.allStreamMetadata[activeStreamIndex_];
781789

782-
double minSeconds = getMinSeconds(streamMetadata);
783-
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
790+
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
791+
std::optional<double> maxSeconds =
792+
streamMetadata.getEndStreamSeconds(seekMode_);
784793

785794
// The frame played at timestamp t and the one played at timestamp `t +
786795
// eps` are probably the same frame, with the same index. The easiest way to
@@ -857,7 +866,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
857866
return frameBatchOutput;
858867
}
859868

860-
double minSeconds = getMinSeconds(streamMetadata);
869+
double minSeconds = streamMetadata.getBeginStreamSeconds(seekMode_);
861870
TORCH_CHECK(
862871
startSeconds >= minSeconds,
863872
"Start seconds is " + std::to_string(startSeconds) +
@@ -866,7 +875,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
866875

867876
// Note that if we can't determine the maximum seconds from the metadata,
868877
// then we assume upper range is valid.
869-
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
878+
std::optional<double> maxSeconds =
879+
streamMetadata.getEndStreamSeconds(seekMode_);
870880
if (maxSeconds.has_value()) {
871881
TORCH_CHECK(
872882
startSeconds < maxSeconds.value(),
@@ -1439,47 +1449,6 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14391449
// STREAM AND METADATA APIS
14401450
// --------------------------------------------------------------------------
14411451

1442-
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
1443-
const StreamMetadata& streamMetadata) {
1444-
switch (seekMode_) {
1445-
case SeekMode::custom_frame_mappings:
1446-
case SeekMode::exact:
1447-
return streamMetadata.numFramesFromContent.value();
1448-
case SeekMode::approximate: {
1449-
return streamMetadata.numFramesFromHeader;
1450-
}
1451-
default:
1452-
TORCH_CHECK(false, "Unknown SeekMode");
1453-
}
1454-
}
1455-
1456-
double SingleStreamDecoder::getMinSeconds(
1457-
const StreamMetadata& streamMetadata) {
1458-
switch (seekMode_) {
1459-
case SeekMode::custom_frame_mappings:
1460-
case SeekMode::exact:
1461-
return streamMetadata.beginStreamPtsSecondsFromContent.value();
1462-
case SeekMode::approximate:
1463-
return 0;
1464-
default:
1465-
TORCH_CHECK(false, "Unknown SeekMode");
1466-
}
1467-
}
1468-
1469-
std::optional<double> SingleStreamDecoder::getMaxSeconds(
1470-
const StreamMetadata& streamMetadata) {
1471-
switch (seekMode_) {
1472-
case SeekMode::custom_frame_mappings:
1473-
case SeekMode::exact:
1474-
return streamMetadata.endStreamPtsSecondsFromContent.value();
1475-
case SeekMode::approximate: {
1476-
return streamMetadata.durationSecondsFromHeader;
1477-
}
1478-
default:
1479-
TORCH_CHECK(false, "Unknown SeekMode");
1480-
}
1481-
}
1482-
14831452
// --------------------------------------------------------------------------
14841453
// VALIDATION UTILS
14851454
// --------------------------------------------------------------------------
@@ -1529,7 +1498,7 @@ void SingleStreamDecoder::validateFrameIndex(
15291498

15301499
// Note that if we do not have the number of frames available in our
15311500
// metadata, then we assume that the frameIndex is valid.
1532-
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
1501+
std::optional<int64_t> numFrames = streamMetadata.getNumFrames(seekMode_);
15331502
if (numFrames.has_value()) {
15341503
if (frameIndex >= numFrames.value()) {
15351504
throw std::out_of_range(

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "DeviceInterface.h"
1717
#include "FFMPEGCommon.h"
1818
#include "Frame.h"
19+
#include "Metadata.h"
1920
#include "StreamOptions.h"
2021
#include "Transform.h"
2122

@@ -30,8 +31,6 @@ class SingleStreamDecoder {
3031
// CONSTRUCTION API
3132
// --------------------------------------------------------------------------
3233

33-
enum class SeekMode { exact, approximate, custom_frame_mappings };
34-
3534
// Creates a SingleStreamDecoder from the video at videoFilePath.
3635
explicit SingleStreamDecoder(
3736
const std::string& videoFilePath,
@@ -60,6 +59,12 @@ class SingleStreamDecoder {
6059
// Returns the metadata for the container.
6160
ContainerMetadata getContainerMetadata() const;
6261

62+
// Returns the seek mode of this decoder.
63+
SeekMode getSeekMode() const;
64+
65+
// Returns the active stream index. Returns -2 if no stream is active.
66+
int getActiveStreamIndex() const;
67+
6368
// Returns the key frame indices as a tensor. The tensor is 1D and contains
6469
// int64 values, where each value is the frame index for a key frame.
6570
torch::Tensor getKeyFrameIndices();
@@ -312,10 +317,6 @@ class SingleStreamDecoder {
312317
// index. Note that this index may be truncated for some files.
313318
int getBestStreamIndex(AVMediaType mediaType);
314319

315-
std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
316-
double getMinSeconds(const StreamMetadata& streamMetadata);
317-
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);
318-
319320
// --------------------------------------------------------------------------
320321
// VALIDATION UTILS
321322
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)